# Owner(s): ["module: autograd"]

import collections
import contextlib
import functools
import gc
import io
import math
import operator
import os
import pickle
import random
import subprocess
import sys
import tempfile
import threading
import time
import unittest
import uuid
import warnings
import weakref
from collections import OrderedDict
from copy import deepcopy
from functools import partial, reduce
from itertools import product
from operator import mul
from typing import List, Tuple, TYPE_CHECKING

import torch
import torch.autograd._functions
import torch.autograd.forward_ad as fwAD
from torch import inf, nan, nn
from torch.autograd import (
    _calculate_shape,
    detect_anomaly,
    Function,
    kineto_available,
    Variable,
)
from torch.autograd.function import InplaceFunction, once_differentiable
from torch.autograd.graph import GradientEdge
from torch.autograd.profiler import emit_itt, emit_nvtx, profile, record_function
from torch.autograd.profiler_util import (
    _format_time,
    EventList,
    FunctionEvent,
    FunctionEventAvg,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    dtypes,
    dtypesIfCUDA,
    dtypesIfMPS,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_methods_invocations import mask_not_all_zeros
from torch.testing._internal.common_utils import (
    disable_gc,
    gradcheck,
    gradgradcheck,
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    run_tests,
    set_warn_always_context,
    skipIfMps,
    skipIfNoLapack,
    skipIfTorchDynamo,
    slowTest,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import (
    checkpoint,
    checkpoint_sequential,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)
from torch.utils.cpp_extension import load_inline
from torch.utils.flop_counter import FlopCounterMode


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

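# Render the autograd graph rooted at a grad_fn Node as a nested string, e.g.
# "MulBackward0(AccumulateGrad(), AccumulateGrad())", by recursively following
# next_functions. Several tests below compare against these strings.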
def graph_desc(fn):
    if fn is None:
        return "None"
    result = type(fn).__name__ + "("
    next_functions = fn.next_functions
    for next_fn, _ in next_functions:
        result += graph_desc(next_fn)
        result += ", "
    if next_functions:
        result = result[:-2]
    return result + ")"


class TestAutograd(TestCase):
    def test_copy_slices_graph_task_updates(self):
        def f1(x, y):
            out = x.clone().view(-1)
            out += y
            return out

        def f2(x, y):
            out = x.clone().view(-1)
            b = out * 2
            out += y
            return out + b

        x = torch.rand(2, requires_grad=True)
        y = torch.rand(2, requires_grad=True)

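        # DelayedError is an internal helper whose backward raises the given
        # message; wrapping y with it lets the assertions below check exactly
        # which Nodes the CopySlices graph task decides to execute.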
        y_safe = torch._C._functions.DelayedError("Boom!", 1)(y)

        for f in [f1, f2]:
            # Ensure that the error Node works
            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                out.sum().backward()

            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                torch.autograd.grad(out.sum(), y)

            # Ensure that if we don't ask for y, it doesn't crash
            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), x)

            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), y_safe)

            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), (x, y_safe))

        # Ensure that we don't run extra view Node
        def f3(x, y):
            out = x.clone().view(-1)

            def hook(*args):
                # This should never be called!
                self.assertTrue(False)

            out.register_hook(hook)

            b = out + y
            out += y
            return out + b, b

        out, b = f3(x, y_safe)
        torch.autograd.grad(out.sum(), (b, y_safe))

    def test_grad_mode_class_decoration(self):
        # Decorating a class is deprecated and should not be used
        with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"):

            @torch.no_grad()
            class Foo:
                def __init__(self) -> None:
                    assert not torch.is_grad_enabled()

                def foo(self):
                    # Not applied to methods
                    assert torch.is_grad_enabled()

            # Show that we can actually construct the class
            foo = Foo()
            foo.foo()

        # Decorating functions or methods is fine though
        with warnings.catch_warnings(record=True) as w:

            @torch.no_grad()
            def foo():
                assert not torch.is_grad_enabled()

            foo()

            class Foo2:
                @torch.no_grad()
                def __init__(self) -> None:
                    assert not torch.is_grad_enabled()

                @torch.no_grad()
                def foo(self):
                    assert not torch.is_grad_enabled()

            foo2 = Foo2()
            foo2.foo()

        self.assertEqual(len(w), 0)

    def test_tensor_grad_warnings(self):
        dummy = torch.empty(1)

        with warnings.catch_warnings(record=True) as w:
            # Accessing .grad on leaf
            dummy.requires_grad_()
            foo = dummy.grad
            self.assertEqual(len(w), 0)

            # Accessing .grad on non-leaf
            dummy = dummy.clone()
            foo = dummy.grad
            self.assertEqual(len(w), 1)

            # Accessing .grad on non-leaf that retains gradients
            dummy.retain_grad()
            foo = dummy.grad
            self.assertEqual(len(w), 1)

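    # Shared helper for the custom Function tests below: applies `cls` to
    # (x, 2, y), runs backward with create_graph=True so that the resulting
    # .grad tensors themselves have a grad_fn, and returns the leaves.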
    def _function_test(self, cls):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)
        result = cls.apply(x, 2, y)
        go = torch.ones((), requires_grad=True)
        result.sum().backward(go, create_graph=True)

        self.assertEqual(x.grad, y + torch.ones(5, 5))
        self.assertEqual(y.grad, x + torch.ones(5, 5) * 2)
        self.assertIsNotNone(x.grad.grad_fn)
        self.assertIsNotNone(y.grad.grad_fn)

        return x, y

    def test_function(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, tensor1, pyscalar, tensor2):
                ctx.pyscalar = pyscalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_tensors
                # NOTE: self is the test case here
                self.assertIsInstance(var1, torch.Tensor)
                self.assertIsInstance(var2, torch.Tensor)
                self.assertIsInstance(grad_output, torch.Tensor)
                return (
                    grad_output + grad_output * var2,
                    None,
                    grad_output * ctx.pyscalar + grad_output * var1,
                )

        x, y = self._function_test(MyFunction)

        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertExpected(x_grad_desc, "x_grad_desc")
        self.assertExpected(y_grad_desc, "y_grad_desc")

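    # once_differentiable wraps backward so that it runs with grad mode
    # disabled and its outputs are not further differentiable; the graph_desc
    # checks below expect an Error node in place of a differentiable graph.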
    def test_once_differentiable(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, tensor1, pyscalar, tensor2):
                ctx.pyscalar = pyscalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

            @staticmethod
            @once_differentiable
            def backward(ctx, grad_output):
                self.assertFalse(torch.is_grad_enabled())
                t1, t2 = ctx.saved_tensors
                return (
                    grad_output + grad_output * t2,
                    None,
                    grad_output * ctx.pyscalar + grad_output * t1,
                )

        x, y = self._function_test(MyFunction)
        self.assertEqual(
            graph_desc(x.grad.grad_fn),
            "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
        )
        self.assertEqual(
            graph_desc(y.grad.grad_fn),
            "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
        )

    def test_function_returns_input(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad * 2

        for shape in [(1,), ()]:
            v = torch.ones(shape, requires_grad=True)
            MyFunction.apply(v).backward()
            self.assertEqual(v.grad, torch.full(shape, 2.0))

            with torch.no_grad():
                v.grad.zero_()
            MyFunction.apply(v.clone()).backward()
            self.assertEqual(v.grad, torch.full(shape, 2.0))

    def test_function_returns_undefined_tensor(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, grad):
                return None

        # Test that undefined tensors returned from a custom backward function
        # are propagated as undefined rather than as tensors full of zeros
        x = torch.ones(1, requires_grad=True)

        MyFunction.apply(x).backward()
        self.assertIsNone(x.grad)

        MyFunction.apply(x**2).backward()
        self.assertIsNone(x.grad)

        MyFunction.apply(x).sum().backward()
        self.assertIsNone(x.grad)

        self.assertIsNone(
            torch.autograd.grad(MyFunction.apply(x), x, allow_unused=True)[0]
        )

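    # By default, undefined incoming gradients are materialized as zero
    # tensors before being passed to a custom backward; UndefinedGrad() (an
    # internal helper) feeds an undefined gradient into the graph, so the
    # backward here should observe torch.zeros(1).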
    def test_materialize_grads(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                self.assertEqual(grad, torch.zeros(1))
                return grad

        x = torch.ones(1, requires_grad=True)
        torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()

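    # With ctx.set_materialize_grads(False), undefined incoming gradients are
    # passed to backward as None instead of being materialized as zeros.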
    def test_dont_materialize_grads(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.set_materialize_grads(False)
                return x

            @staticmethod
            def backward(ctx, grad):
                self.assertIsNone(grad)
                return grad

        x = torch.ones(1, requires_grad=True)
        torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()

    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
    def test_set_materialize_non_diff_grads(self):
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out0 = x.clone()
                out1 = x.clone()
                ctx.mark_non_differentiable(out1)
                ctx._materialize_non_diff_grads = False
                return out0, out1

            @staticmethod
            def backward(ctx, g0, g1):
                self.assertIsNone(g1)
                return g0

        a = torch.tensor(1.0, requires_grad=True)
        out = Func.apply(a)[0]
        out.backward()

    def test_legacy_function_deprecation_exception(self):
        # Trigger exception
        class MyFunction(Function):
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

        # Check exception occurs
        with self.assertRaisesRegex(
            RuntimeError,
            "Legacy autograd function with non-static forward method is deprecated",
        ):
            MyFunction()(torch.randn(3, 4))

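    # Helper Function whose backward always raises; reused by tests that need
    # a backward pass to fail partway through.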
    class SimulateBackwardError(Function):
        @staticmethod
        def forward(ctx, input):
            return input.clone()

        @staticmethod
        @once_differentiable
        def backward(ctx, input):
            raise Exception("Simulate error on backward pass")  # noqa: TRY002

    def test_custom_function_exception(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)

        tmp = (t1 + t2) * (t1 + t2)
        t3 = TestAutograd.SimulateBackwardError.apply(tmp)
        with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
            t3.sum().backward()

    def test_custom_function_non_tensor_inputs_outputs(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, t1, t2, scale, t3):
                t4 = t1 + t2 * t3
                t5 = t1 * t2 + t3
                t4 *= scale
                t5 *= scale

                # Save scale
                ctx.scale = scale
                ctx.save_for_backward(t1, t2, t3)
                return scale, t4, None, True, t5, "bar", t1

            @staticmethod
            @once_differentiable
            def backward(ctx, *grads):
                # Verify grads
                self.assertEqual(7, len(grads))
                self.assertIsNone(grads[0])
                self.assertIsNone(grads[2])
                self.assertIsNone(grads[3])
                self.assertIsNone(grads[5])

                scale = ctx.scale
                var1, var2, var3 = ctx.saved_tensors
                return (
                    grads[1] * scale + grads[4] * var2 * scale + grads[6],
                    grads[1] * var3 * scale + grads[4] * var1 * scale,
                    None,
                    grads[1] * var2 * scale + grads[4] * scale,
                )

        t1 = torch.rand(10, dtype=torch.double, requires_grad=True)
        t2 = torch.rand(10, dtype=torch.double, requires_grad=True)
        t3 = torch.rand(10, dtype=torch.double)
        scale = random.randint(0, 10)
        res = MyFunction.apply(t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

        # Validate running backward.
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertIsNotNone(t1.grad)
        self.assertIsNotNone(t2.grad)
        self.assertIsNone(t3.grad)

        # Test gradcheck
        def foo(t1, t2, t3):
            res = MyFunction.apply(t1, t2, scale, t3)
            return res[1], res[4], res[6]

        gradcheck(foo, (t1, t2, t3))

    def test_custom_function_no_tensors(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, t1, t2, scale, t3):
                t4 = t1 + t2 * t3
                t5 = t1 * t2 + t3
                t4 *= scale
                t5 *= scale
                return scale, t4, None, True, t5, "bar", t1

            @staticmethod
            @once_differentiable
            def backward(ctx, *args):
                return (args[0], args[1], None, args[2])

        t1 = random.random()
        t2 = random.random()
        t3 = random.random()
        scale = random.randint(0, 10)
        res = MyFunction.apply(t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

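    # The engine validates that gradients returned by a custom backward match
    # the metadata of the corresponding forward inputs; returning a tensor of
    # the wrong shape should be rejected with an "expected shape" error.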
    def test_invalid_gradients(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, grad_output):
                return torch.randn(10, dtype=torch.float)

        with self.assertRaisesRegex(RuntimeError, "expected shape"):
            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
            MyFunction.apply(input).sum().backward()

    def test_unrelated_inputs(self):
        # test to ensure grad(grad)check runs successfully even if there are
        # unrelated (but differentiable) inputs

        def my_function(x, y):
            return x * x

        x = torch.rand(10, dtype=torch.double, requires_grad=True)
        y = torch.rand(10, dtype=torch.double, requires_grad=True)

        gradcheck(my_function, (x, y))
        gradgradcheck(my_function, (x, y))

    def test_not_implemented_grad(self):
        a = torch.rand(2, requires_grad=True)
        # if grad for nextafter ends up being implemented, this should be changed
        y = torch.nextafter(a, a).sum()
        with self.assertRaisesRegex(
            NotImplementedError, "the derivative for .* is not implemented"
        ):
            y.backward()

    def test_not_implemented_fwad(self):
        x = torch.randn(3)
        v = torch.rand(3)

        with fwAD.dual_level():
            dual_x = fwAD.make_dual(x, v)

            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"

            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                # if forward AD ends up being implemented for torch.igamma, choose a different op
                torch.igamma(dual_x, dual_x)

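    # saved_tensors_hooks is a context manager that registers pack/unpack
    # hooks for tensors saved by autograd; the next two tests deliberately
    # misuse __enter__/__exit__ from inside the unpack hook during backward
    # and only check that nothing crashes or leaks.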
    def test_saved_tensor_hooks_extra_exit_during_bw_no_crash(self):
        # This usage of saved tensor hooks is not supported, but should not crash
        def unpack(x):
            ctx_1.__exit__()
            return x

        ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
        ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x)

        for i in range(10):
            with ctx_2:
                ctx_1.__enter__()
                x = torch.randn(3, 3, requires_grad=True)
                x.sin().sum().backward()

        # Clean up
        for i in range(10):
            ctx_1.__exit__()

        # Validate there are no more hooks on the stack
        a = torch.tensor(1.0, requires_grad=True)
        y = a.exp()
        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)

    def test_saved_tensor_hooks_extra_enter_during_bw_no_leak(self):
        # This usage of saved tensor hooks is not supported, but should not leak
        def scope():
            def unpack(x):
                weak_ctx_1().__enter__()
                return x

            ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
            weak_ctx_1 = weakref.ref(ctx_1)

            x = torch.randn(3, 3, requires_grad=True)
            with ctx_1:
                x.sin().sum().backward()
            return weakref.ref(unpack)

        with disable_gc():
            unpack_hook_ref = scope()
            self.assertIsNone(unpack_hook_ref())

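    # torch._C._will_engine_execute_node is a private API; per the checks
    # below, it can only be queried during the backward pass and reports
    # whether the engine will run a given Node for the current
    # .backward()/.grad() call.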
    def test_will_engine_execute_node(self):
        counter = [0]

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, gO):
                return gO * 2

        def get_grad_fn(t):
            if t.requires_grad and t.grad_fn is None:
                return t.clone().grad_fn.next_functions[0][0]
            else:
                return t.grad_fn

        a = torch.randn(2, 3, 4, requires_grad=True)
        a2 = torch.randn(2, 3, 4, requires_grad=True)
        b = a * a2
        b2 = b.cos()
        c = MyFunction.apply(b)

        should_execute = list(map(get_grad_fn, (a, b, c)))
        should_not_execute = list(map(get_grad_fn, (a2, b2)))

        def fn(x):
            counter[0] += 1

            for g in should_execute:
                self.assertTrue(torch._C._will_engine_execute_node(g))

            for g in should_not_execute:
                self.assertFalse(torch._C._will_engine_execute_node(g))

        b.register_hook(fn)
        c.register_hook(fn)

        # .backward(inputs=) is OK
        out = c.sum()
        torch.autograd.backward(out, inputs=(a, b), retain_graph=True)
        self.assertEqual(counter[0], 2)

        # .backward() is OK
        should_execute = list(map(get_grad_fn, (a, a2, b, c)))
        should_not_execute = list(map(get_grad_fn, (b2,)))
        torch.autograd.backward(out, retain_graph=True)

        # .grad is NOT OK when leaf is passed (this is the current state, subject to change)
        with self.assertRaisesRegex(
            RuntimeError, "are currently running autograd.grad()"
        ):
            torch.autograd.grad(out, (a,))

        # .grad is OK when non-leaf is passed
        a = torch.randn(1, 2, 3, requires_grad=True) * 2
        b = a * 2

        def fn(x):
            # Check a non-leaf
            counter[0] += 1
            self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))

        b.register_hook(fn)
        counter[0] = 0
        torch.autograd.grad(b.sum(), (a,))
        self.assertEqual(counter[0], 1)

        # Verify other errors are raised
        with self.assertRaisesRegex(RuntimeError, "during the backward pass"):
            torch._C._will_engine_execute_node(out.grad_fn)

        with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
            torch._C._will_engine_execute_node(out)

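    # The tests below exercise the "new-style" custom Function API where
    # forward() does not take ctx and a separate setup_context(ctx, inputs,
    # output) staticmethod is responsible for saving state for backward.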
681*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_vmap_defaults(self):
682*da0073e9SAndroid Build Coastguard Worker        class MySquare(Function):
683*da0073e9SAndroid Build Coastguard Worker            @staticmethod
684*da0073e9SAndroid Build Coastguard Worker            def forward(x):
685*da0073e9SAndroid Build Coastguard Worker                return x**2
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker            @staticmethod
688*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output):
689*da0073e9SAndroid Build Coastguard Worker                (x,) = inputs
690*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
691*da0073e9SAndroid Build Coastguard Worker
692*da0073e9SAndroid Build Coastguard Worker            @staticmethod
693*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
694*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
695*da0073e9SAndroid Build Coastguard Worker                return gO * 2 * x
696*da0073e9SAndroid Build Coastguard Worker
697*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(MySquare.generate_vmap_rule)
698*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(MySquare, "vmap"))
699*da0073e9SAndroid Build Coastguard Worker
700*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_setup_context_simple(self):
701*da0073e9SAndroid Build Coastguard Worker        class MySquare(Function):
702*da0073e9SAndroid Build Coastguard Worker            @staticmethod
703*da0073e9SAndroid Build Coastguard Worker            def forward(x):
704*da0073e9SAndroid Build Coastguard Worker                return x**2
705*da0073e9SAndroid Build Coastguard Worker
706*da0073e9SAndroid Build Coastguard Worker            @staticmethod
707*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output):
708*da0073e9SAndroid Build Coastguard Worker                (x,) = inputs
709*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker            @staticmethod
712*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
713*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
714*da0073e9SAndroid Build Coastguard Worker                return gO * 2 * x
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([], requires_grad=True)
717*da0073e9SAndroid Build Coastguard Worker        y = MySquare.apply(x)
718*da0073e9SAndroid Build Coastguard Worker        (gx,) = torch.autograd.grad(y, x)
719*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gx, 2 * x)
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_setup_context_multi_output(self):
722*da0073e9SAndroid Build Coastguard Worker        # Multiple outputs with some non-Tensor outputs.
723*da0073e9SAndroid Build Coastguard Worker        class MySquare(Function):
724*da0073e9SAndroid Build Coastguard Worker            @staticmethod
725*da0073e9SAndroid Build Coastguard Worker            def forward(x):
726*da0073e9SAndroid Build Coastguard Worker                two_x = x.item() * 2
727*da0073e9SAndroid Build Coastguard Worker                return x**2, two_x
728*da0073e9SAndroid Build Coastguard Worker
729*da0073e9SAndroid Build Coastguard Worker            @staticmethod
730*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output):
731*da0073e9SAndroid Build Coastguard Worker                (x,) = inputs
732*da0073e9SAndroid Build Coastguard Worker                _, two_x = output
733*da0073e9SAndroid Build Coastguard Worker                ctx.two_x = two_x
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker            @staticmethod
736*da0073e9SAndroid Build Coastguard Worker            @once_differentiable
737*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO, _):
738*da0073e9SAndroid Build Coastguard Worker                return gO * ctx.two_x
739*da0073e9SAndroid Build Coastguard Worker
740*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([], requires_grad=True)
741*da0073e9SAndroid Build Coastguard Worker        y, _ = MySquare.apply(x)
742*da0073e9SAndroid Build Coastguard Worker        (gx,) = torch.autograd.grad(y, x)
743*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(gx, 2 * x)
744*da0073e9SAndroid Build Coastguard Worker
745*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_setup_context_multi_input(self):
746*da0073e9SAndroid Build Coastguard Worker        class MyReshape(Function):
747*da0073e9SAndroid Build Coastguard Worker            @staticmethod
748*da0073e9SAndroid Build Coastguard Worker            def forward(x, shape, scale_forward, scale_backward):
749*da0073e9SAndroid Build Coastguard Worker                return x.reshape(shape) * scale_forward
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker            @staticmethod
752*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output):
753*da0073e9SAndroid Build Coastguard Worker                x, shape, scale_forward, scale_backward = inputs
754*da0073e9SAndroid Build Coastguard Worker                ctx.scale_backward = scale_backward
755*da0073e9SAndroid Build Coastguard Worker                ctx.x_shape = x.shape
756*da0073e9SAndroid Build Coastguard Worker
757*da0073e9SAndroid Build Coastguard Worker            @staticmethod
758*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
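                # Non-Tensor inputs (shape and the two scales) receive None as
                # their "gradient" in the tuple returned below.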
759*da0073e9SAndroid Build Coastguard Worker                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        class MyReshapeRef(Function):
762*da0073e9SAndroid Build Coastguard Worker            @staticmethod
763*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, shape, scale_forward, scale_backward):
764*da0073e9SAndroid Build Coastguard Worker                ctx.scale_backward = scale_backward
765*da0073e9SAndroid Build Coastguard Worker                ctx.x_shape = x.shape
766*da0073e9SAndroid Build Coastguard Worker                return x.reshape(shape) * scale_forward
767*da0073e9SAndroid Build Coastguard Worker
768*da0073e9SAndroid Build Coastguard Worker            @staticmethod
769*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
770*da0073e9SAndroid Build Coastguard Worker                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker        def test(x, shape, scale_forward, scale_backward):
773*da0073e9SAndroid Build Coastguard Worker            y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum()
774*da0073e9SAndroid Build Coastguard Worker            (gx,) = torch.autograd.grad(y, x)
775*da0073e9SAndroid Build Coastguard Worker
776*da0073e9SAndroid Build Coastguard Worker            y_expected = MyReshapeRef.apply(
777*da0073e9SAndroid Build Coastguard Worker                x, shape, scale_forward, scale_backward
778*da0073e9SAndroid Build Coastguard Worker            ).sum()
779*da0073e9SAndroid Build Coastguard Worker            (gx_expected,) = torch.autograd.grad(y_expected, x)
780*da0073e9SAndroid Build Coastguard Worker
781*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y_expected, y)
782*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gx_expected, gx)
783*da0073e9SAndroid Build Coastguard Worker
784*da0073e9SAndroid Build Coastguard Worker        test(torch.randn(24, requires_grad=True), (3, 8), 7, 11)
785*da0073e9SAndroid Build Coastguard Worker        test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2)
786*da0073e9SAndroid Build Coastguard Worker
787*da0073e9SAndroid Build Coastguard Worker    def test_multiple_insert_removal_caching(self):
788*da0073e9SAndroid Build Coastguard Worker        torch._C._set_cached_tensors_enabled(True)
789*da0073e9SAndroid Build Coastguard Worker        try:
790*da0073e9SAndroid Build Coastguard Worker            x = torch.rand([4])
791*da0073e9SAndroid Build Coastguard Worker
792*da0073e9SAndroid Build Coastguard Worker            torch._C._add_cached_tensor(x)
793*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch._C._is_cached_tensor(x))
794*da0073e9SAndroid Build Coastguard Worker
795*da0073e9SAndroid Build Coastguard Worker            torch._C._add_cached_tensor(x)
796*da0073e9SAndroid Build Coastguard Worker            torch._C._remove_cached_tensor(x)
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch._C._is_cached_tensor(x))
799*da0073e9SAndroid Build Coastguard Worker        finally:
800*da0073e9SAndroid Build Coastguard Worker            torch._C._set_cached_tensors_enabled(False)
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker    def test_accumulate_grad(self):
803*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.ones(5, 5)
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        def compute_grad(create_graph):
806*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 5, requires_grad=True)
807*da0073e9SAndroid Build Coastguard Worker            y = x + 2
808*da0073e9SAndroid Build Coastguard Worker            y.backward(grad_output, retain_graph=True)
809*da0073e9SAndroid Build Coastguard Worker            x_grad = x.grad
810*da0073e9SAndroid Build Coastguard Worker            x_grad_clone = x.grad.clone()
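            # A second backward accumulates into x.grad; whether it accumulates
            # in place decides whether x_grad (an alias of the old .grad) sees
            # the doubled value checked by the caller.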
811*da0073e9SAndroid Build Coastguard Worker            y.backward(grad_output, create_graph=create_graph)
812*da0073e9SAndroid Build Coastguard Worker            return x_grad, x_grad_clone
813*da0073e9SAndroid Build Coastguard Worker
814*da0073e9SAndroid Build Coastguard Worker        # Accumulate in-place when create_graph is False
815*da0073e9SAndroid Build Coastguard Worker        x_grad, x_grad_clone = compute_grad(create_graph=False)
816*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_grad, x_grad_clone * 2)
817*da0073e9SAndroid Build Coastguard Worker
818*da0073e9SAndroid Build Coastguard Worker        # Accumulate out-of-place when create_graph is True
819*da0073e9SAndroid Build Coastguard Worker        x_grad, x_grad_clone = compute_grad(create_graph=True)
820*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_grad, x_grad_clone)
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker    def test_accumulate_grad_tensor_reference(self):
823*da0073e9SAndroid Build Coastguard Worker        def _test_grad_tensor(
824*da0073e9SAndroid Build Coastguard Worker            params_grad_tensor,
825*da0073e9SAndroid Build Coastguard Worker            backward_grad_tensor,
826*da0073e9SAndroid Build Coastguard Worker            should_preserve_reference,
827*da0073e9SAndroid Build Coastguard Worker            create_graph,
828*da0073e9SAndroid Build Coastguard Worker        ):
829*da0073e9SAndroid Build Coastguard Worker            params = torch.tensor([1.5, 1.5]).requires_grad_()
830*da0073e9SAndroid Build Coastguard Worker            params.grad = params_grad_tensor
831*da0073e9SAndroid Build Coastguard Worker            grad_saved = params.grad
832*da0073e9SAndroid Build Coastguard Worker            params.backward(backward_grad_tensor, create_graph=create_graph)
833*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
834*da0073e9SAndroid Build Coastguard Worker                id(grad_saved) == id(params.grad), should_preserve_reference
835*da0073e9SAndroid Build Coastguard Worker            )
836*da0073e9SAndroid Build Coastguard Worker
837*da0073e9SAndroid Build Coastguard Worker        for create_graph in (False, True):
838*da0073e9SAndroid Build Coastguard Worker            # Accumulating a dense gradient into a sparse gradient changes the `params.grad` reference
839*da0073e9SAndroid Build Coastguard Worker            _test_grad_tensor(
840*da0073e9SAndroid Build Coastguard Worker                torch.sparse_coo_tensor(
841*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
842*da0073e9SAndroid Build Coastguard Worker                ),
843*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1.5, 1.5]),
844*da0073e9SAndroid Build Coastguard Worker                False,  # never accumulates in-place
845*da0073e9SAndroid Build Coastguard Worker                create_graph,
846*da0073e9SAndroid Build Coastguard Worker            )
847*da0073e9SAndroid Build Coastguard Worker
848*da0073e9SAndroid Build Coastguard Worker            # Accumulating a dense gradient into a dense gradient preserves the `params.grad` reference,
849*da0073e9SAndroid Build Coastguard Worker            # but only if create_graph=False.
850*da0073e9SAndroid Build Coastguard Worker            _test_grad_tensor(
851*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1.5, 1.5]),
852*da0073e9SAndroid Build Coastguard Worker                torch.tensor([1.5, 1.5]),
853*da0073e9SAndroid Build Coastguard Worker                not create_graph,
854*da0073e9SAndroid Build Coastguard Worker                create_graph,
855*da0073e9SAndroid Build Coastguard Worker            )
856*da0073e9SAndroid Build Coastguard Worker
857*da0073e9SAndroid Build Coastguard Worker            # Accumulating a sparse gradient into a sparse gradient preserves the `params.grad` reference,
858*da0073e9SAndroid Build Coastguard Worker            # but only if create_graph=False.
859*da0073e9SAndroid Build Coastguard Worker            _test_grad_tensor(
860*da0073e9SAndroid Build Coastguard Worker                torch.sparse_coo_tensor(
861*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
862*da0073e9SAndroid Build Coastguard Worker                ),
863*da0073e9SAndroid Build Coastguard Worker                torch.sparse_coo_tensor(
864*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
865*da0073e9SAndroid Build Coastguard Worker                ),
866*da0073e9SAndroid Build Coastguard Worker                not create_graph,
867*da0073e9SAndroid Build Coastguard Worker                create_graph,
868*da0073e9SAndroid Build Coastguard Worker            )
869*da0073e9SAndroid Build Coastguard Worker
870*da0073e9SAndroid Build Coastguard Worker    def test_accumulate_grad_with_zero_numel_grad(self):
871*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(4, 0, requires_grad=True)
872*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(4, 1, requires_grad=True)
873*da0073e9SAndroid Build Coastguard Worker        c = a + b
874*da0073e9SAndroid Build Coastguard Worker        assert c.shape == (4, 0)
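        # (4, 0) + (4, 1) broadcasts to an empty (4, 0) result, so the reduced
        # gradients below must still be materialized as zeros of each input's shape.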
875*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
876*da0073e9SAndroid Build Coastguard Worker
877*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, torch.zeros(4, 1))
878*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.zeros(4, 0))
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker    def test_hessian_vector(self):
881*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
882*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, requires_grad=True)
883*da0073e9SAndroid Build Coastguard Worker
884*da0073e9SAndroid Build Coastguard Worker        z = x**2 + y * x + y**2
885*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(2, 2), create_graph=True)
886*da0073e9SAndroid Build Coastguard Worker
887*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
888*da0073e9SAndroid Build Coastguard Worker            x_grad = 2 * x + y
889*da0073e9SAndroid Build Coastguard Worker            y_grad = x + 2 * y
890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad)
891*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad)
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker        grad_sum = 2 * x.grad + y.grad
894*da0073e9SAndroid Build Coastguard Worker        grad_sum.backward(torch.ones(2, 2))
895*da0073e9SAndroid Build Coastguard Worker        x_hv = torch.ones(2, 2) * 5
896*da0073e9SAndroid Build Coastguard Worker        y_hv = torch.ones(2, 2) * 4
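        # grad_sum = 2 * (2x + y) + (x + 2y) = 5x + 4y, so backpropagating ones
        # adds 5 to x.grad and 4 to y.grad elementwise.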
897*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad + x_hv)
898*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad + y_hv)
899*da0073e9SAndroid Build Coastguard Worker
900*da0073e9SAndroid Build Coastguard Worker    def test_grad(self):
901*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
902*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, requires_grad=True)
903*da0073e9SAndroid Build Coastguard Worker        z = x**2 + y * x + y**2
904*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(2, 2), create_graph=True)
905*da0073e9SAndroid Build Coastguard Worker
906*da0073e9SAndroid Build Coastguard Worker        x_grad = 2 * x + y
907*da0073e9SAndroid Build Coastguard Worker        y_grad = x + 2 * y
908*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad)
909*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad)
910*da0073e9SAndroid Build Coastguard Worker
911*da0073e9SAndroid Build Coastguard Worker        grad_sum = 2 * x.grad + y.grad
912*da0073e9SAndroid Build Coastguard Worker        x_hv = torch.autograd.grad(
913*da0073e9SAndroid Build Coastguard Worker            outputs=[grad_sum],
914*da0073e9SAndroid Build Coastguard Worker            grad_outputs=[torch.ones(2, 2)],
915*da0073e9SAndroid Build Coastguard Worker            inputs=[x],
916*da0073e9SAndroid Build Coastguard Worker            create_graph=True,
917*da0073e9SAndroid Build Coastguard Worker        )
918*da0073e9SAndroid Build Coastguard Worker        expected_x_hv = torch.ones(2, 2) * 5
919*da0073e9SAndroid Build Coastguard Worker        expected_y_hv = torch.ones(2, 2) * 4
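        # As in test_hessian_vector: grad_sum = 5x + 4y, so d(grad_sum)/dx = 5
        # and d(grad_sum)/dy = 4 elementwise.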
920*da0073e9SAndroid Build Coastguard Worker
921*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_hv[0], expected_x_hv)
922*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad)
923*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad)
924*da0073e9SAndroid Build Coastguard Worker
925*da0073e9SAndroid Build Coastguard Worker        # Test that grad_outputs and outputs have the same shape
926*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.ones(2)
927*da0073e9SAndroid Build Coastguard Worker        try:
928*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
929*da0073e9SAndroid Build Coastguard Worker                outputs=[grad_sum],
930*da0073e9SAndroid Build Coastguard Worker                grad_outputs=[grad_out],
931*da0073e9SAndroid Build Coastguard Worker                inputs=[x],
932*da0073e9SAndroid Build Coastguard Worker                create_graph=True,
933*da0073e9SAndroid Build Coastguard Worker            )
934*da0073e9SAndroid Build Coastguard Worker            self.fail("expected a shape-mismatch RuntimeError")
935*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as error:
936*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
937*da0073e9SAndroid Build Coastguard Worker                str(error),
938*da0073e9SAndroid Build Coastguard Worker                "Mismatch in shape: grad_output[0] has a shape of "
939*da0073e9SAndroid Build Coastguard Worker                + str(grad_out.shape)
940*da0073e9SAndroid Build Coastguard Worker                + " and output[0] has a shape of "
941*da0073e9SAndroid Build Coastguard Worker                + str(grad_sum.shape)
942*da0073e9SAndroid Build Coastguard Worker                + ".",
943*da0073e9SAndroid Build Coastguard Worker            )
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker    def test_grad_to_node(self):
946*da0073e9SAndroid Build Coastguard Worker        def check_matches(out, inp):
947*da0073e9SAndroid Build Coastguard Worker            ref = torch.autograd.grad(out.sum(), inp)
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker            edge = torch.autograd.graph.get_gradient_edge(inp)
950*da0073e9SAndroid Build Coastguard Worker            new = torch.autograd.grad(out.sum(), edge)
951*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ref, new)
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker        # We need to ensure that our main types of Node work (regular cpp Nodes,
954*da0073e9SAndroid Build Coastguard Worker        # AccumulateGrad Nodes, and custom Functions).
955*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True)
956*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
957*da0073e9SAndroid Build Coastguard Worker        check_matches(out, x)
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker        x = x.clone()
960*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
961*da0073e9SAndroid Build Coastguard Worker        check_matches(out, x)
962*da0073e9SAndroid Build Coastguard Worker
963*da0073e9SAndroid Build Coastguard Worker        x = torch.autograd._functions.Resize.apply(x, (2,))
964*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
965*da0073e9SAndroid Build Coastguard Worker        check_matches(out, x)
966*da0073e9SAndroid Build Coastguard Worker
967*da0073e9SAndroid Build Coastguard Worker        x = torch.var_mean(x)[1]
968*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
969*da0073e9SAndroid Build Coastguard Worker        check_matches(out, x)
970*da0073e9SAndroid Build Coastguard Worker
971*da0073e9SAndroid Build Coastguard Worker    def test_grad_to_node_set(self):
972*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True)
973*da0073e9SAndroid Build Coastguard Worker        x_edge = torch.autograd.graph.get_gradient_edge(x)
974*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
975*da0073e9SAndroid Build Coastguard Worker
976*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
977*da0073e9SAndroid Build Coastguard Worker            x.set_(torch.rand_like(x))
978*da0073e9SAndroid Build Coastguard Worker
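        # set_() under no_grad changes which node x maps to in the graph, so
        # differentiating w.r.t. the tensor now fails, while the previously
        # captured GradientEdge still refers to the node recorded when `out`
        # was created.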
979*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "to not have been used in the graph"):
980*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(out.sum(), x)
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker        # Works
983*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(out.sum(), x_edge)
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker    def test_grad_to_node_inplace(self):
986*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True).clone()
987*da0073e9SAndroid Build Coastguard Worker        x_edge = torch.autograd.graph.get_gradient_edge(x)
988*da0073e9SAndroid Build Coastguard Worker        x *= 2
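        # x_edge was captured before the in-place multiply, so it receives the
        # gradient w.r.t. the pre-mutation value (2 * ones), while x itself now
        # maps to the multiply's output and receives ones.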
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker        g_old, g_new = torch.autograd.grad(x.sum(), (x_edge, x))
991*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g_old, 2 * torch.ones_like(x))
992*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(g_new, torch.ones_like(x))
993*da0073e9SAndroid Build Coastguard Worker
994*da0073e9SAndroid Build Coastguard Worker    def test_grad_to_node_multi(self):
995*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True).clone()
996*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(2, requires_grad=True).clone()
997*da0073e9SAndroid Build Coastguard Worker
998*da0073e9SAndroid Build Coastguard Worker        out = x + y
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker        ref = torch.autograd.grad(out.sum(), (x, y))
1001*da0073e9SAndroid Build Coastguard Worker
1002*da0073e9SAndroid Build Coastguard Worker        inp_edges = (
1003*da0073e9SAndroid Build Coastguard Worker            GradientEdge(x.grad_fn, x.output_nr),
1004*da0073e9SAndroid Build Coastguard Worker            GradientEdge(y.grad_fn, y.output_nr),
1005*da0073e9SAndroid Build Coastguard Worker        )
1006*da0073e9SAndroid Build Coastguard Worker        new = torch.autograd.grad(out.sum(), inp_edges)
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref, new)
1009*da0073e9SAndroid Build Coastguard Worker
1010*da0073e9SAndroid Build Coastguard Worker    def test_grad_to_node_materialize(self):
1011*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True).clone()
1012*da0073e9SAndroid Build Coastguard Worker        edge_x = GradientEdge(x.grad_fn, x.output_nr)
1013*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(2, requires_grad=True).clone()
1014*da0073e9SAndroid Build Coastguard Worker        edge_y = GradientEdge(y.grad_fn, y.output_nr)
1015*da0073e9SAndroid Build Coastguard Worker
1016*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
1017*da0073e9SAndroid Build Coastguard Worker
1018*da0073e9SAndroid Build Coastguard Worker        # Works
1019*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(
1020*da0073e9SAndroid Build Coastguard Worker            out.sum(), (edge_x, y), allow_unused=True, materialize_grads=True
1021*da0073e9SAndroid Build Coastguard Worker        )
1022*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(
1023*da0073e9SAndroid Build Coastguard Worker            out.sum(), (x, y), allow_unused=True, materialize_grads=True
1024*da0073e9SAndroid Build Coastguard Worker        )
1025*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(out.sum(), (x, edge_y), allow_unused=True)
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1028*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1029*da0073e9SAndroid Build Coastguard Worker            "materialize_grads cannot be used when the given input is a GradientEdge",
1030*da0073e9SAndroid Build Coastguard Worker        ):
1031*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
1032*da0073e9SAndroid Build Coastguard Worker                out.sum(), (x, edge_y), allow_unused=True, materialize_grads=True
1033*da0073e9SAndroid Build Coastguard Worker            )
1034*da0073e9SAndroid Build Coastguard Worker
1035*da0073e9SAndroid Build Coastguard Worker    def test_backward_to_node(self):
1036*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True).clone()
1037*da0073e9SAndroid Build Coastguard Worker        edge_x = GradientEdge(x.grad_fn, x.output_nr)
1038*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(2, requires_grad=True).clone()
1039*da0073e9SAndroid Build Coastguard Worker        edge_y = GradientEdge(y.grad_fn, y.output_nr)
1040*da0073e9SAndroid Build Coastguard Worker
1041*da0073e9SAndroid Build Coastguard Worker        out = x.clone()
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker        # All should work in this case
1044*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(out.sum(), inputs=(edge_x, y))
1045*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(out.sum(), inputs=(x, y))
1046*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(out.sum(), inputs=(x, edge_y))
1047*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y))
1048*da0073e9SAndroid Build Coastguard Worker
1049*da0073e9SAndroid Build Coastguard Worker    def test_grad_fn_input_metadata(self):
1050*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, requires_grad=True, dtype=torch.float32)
1051*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(2, requires_grad=True, dtype=torch.float32)
1052*da0073e9SAndroid Build Coastguard Worker        z = x * y
1053*da0073e9SAndroid Build Coastguard Worker        z_metadata = z.grad_fn._input_metadata[0]
1054*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z_metadata.shape, (2,))
1055*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z_metadata.dtype, torch.float32)
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker        # Multiple outputs
1058*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(3, 3, requires_grad=True)
1059*da0073e9SAndroid Build Coastguard Worker        var, _ = torch.var_mean(b, dim=0)
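        # The shared grad_fn expects one incoming gradient per forward output,
        # so there are two metadata entries, each of shape (3,) after reducing
        # over dim 0.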
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker        metadata_0 = var.grad_fn._input_metadata[0]
1062*da0073e9SAndroid Build Coastguard Worker        metadata_1 = var.grad_fn._input_metadata[1]
1063*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(metadata_0.shape, (3,))
1064*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(metadata_1.shape, (3,))
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        # Preserves symints
1067*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
1068*da0073e9SAndroid Build Coastguard Worker            [torch.randn(3, 2), torch.randn(2, 2)],
1069*da0073e9SAndroid Build Coastguard Worker            layout=torch.jagged,
1070*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
1071*da0073e9SAndroid Build Coastguard Worker        )
1072*da0073e9SAndroid Build Coastguard Worker        nt_metadata = nt.clone().grad_fn._input_metadata[0]
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
1075*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_metadata.shape, nt.shape)
1076*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(nt_metadata.is_nested_tensor)
1077*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(nt_metadata.is_cpp_nested_tensor)
1078*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(nt_metadata.dtype, nt.dtype)
1079*da0073e9SAndroid Build Coastguard Worker
1080*da0073e9SAndroid Build Coastguard Worker        class Test(torch.autograd.Function):
1081*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1082*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
1083*da0073e9SAndroid Build Coastguard Worker                return x
1084*da0073e9SAndroid Build Coastguard Worker
1085*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1086*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
1087*da0073e9SAndroid Build Coastguard Worker                return grad_output
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, requires_grad=True)
1090*da0073e9SAndroid Build Coastguard Worker        x = Test.apply(x)
1091*da0073e9SAndroid Build Coastguard Worker        metadata = x.grad_fn._input_metadata[0]
1092*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(metadata.shape, (3, 3))
1093*da0073e9SAndroid Build Coastguard Worker
1094*da0073e9SAndroid Build Coastguard Worker    def test_gradient_edge_output(self):
1095*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0, 2.0], requires_grad=True)
1096*da0073e9SAndroid Build Coastguard Worker
1097*da0073e9SAndroid Build Coastguard Worker        def fn(x, reduce=True):
1098*da0073e9SAndroid Build Coastguard Worker            tmp = x.sin().cos()
1099*da0073e9SAndroid Build Coastguard Worker            if reduce:
1100*da0073e9SAndroid Build Coastguard Worker                tmp = tmp.sum()
1101*da0073e9SAndroid Build Coastguard Worker            out = tmp.exp().clone().sin().sum()
1102*da0073e9SAndroid Build Coastguard Worker            tmp_edge = torch.autograd.graph.get_gradient_edge(tmp)
1103*da0073e9SAndroid Build Coastguard Worker            return out, tmp_edge
1104*da0073e9SAndroid Build Coastguard Worker
1105*da0073e9SAndroid Build Coastguard Worker        # Compute fn backward in two steps
1106*da0073e9SAndroid Build Coastguard Worker        out, tmp_edge = fn(x)
1107*da0073e9SAndroid Build Coastguard Worker        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker        (x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))
1110*da0073e9SAndroid Build Coastguard Worker
1111*da0073e9SAndroid Build Coastguard Worker        # Compare against computing the backward in one go.
1112*da0073e9SAndroid Build Coastguard Worker        out, _ = fn(x)
1113*da0073e9SAndroid Build Coastguard Worker        (x_grad_ref,) = torch.autograd.grad(out, (x,))
1114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_grad, x_grad_ref)
1115*da0073e9SAndroid Build Coastguard Worker
1116*da0073e9SAndroid Build Coastguard Worker        # Incorrect case: grad_outputs not passed/implicitly None and output is
1117*da0073e9SAndroid Build Coastguard Worker        # not a scalar
1118*da0073e9SAndroid Build Coastguard Worker        out, tmp_edge = fn(x, reduce=False)
1119*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1120*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "grad can be implicitly created only for scalar output"
1121*da0073e9SAndroid Build Coastguard Worker        ):
1122*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(tmp_edge, (x,))
1123*da0073e9SAndroid Build Coastguard Worker
1124*da0073e9SAndroid Build Coastguard Worker        # grad_outputs is None and the output is a scalar: this is fine
1125*da0073e9SAndroid Build Coastguard Worker        out, tmp_edge = fn(x, reduce=True)
1126*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(tmp_edge, (x,))
1127*da0073e9SAndroid Build Coastguard Worker
1128*da0073e9SAndroid Build Coastguard Worker        # Incorrect case: grad_outputs wrong size
1129*da0073e9SAndroid Build Coastguard Worker        out, tmp_edge = fn(x)
1130*da0073e9SAndroid Build Coastguard Worker        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
1131*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
1132*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
1133*da0073e9SAndroid Build Coastguard Worker                tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
1134*da0073e9SAndroid Build Coastguard Worker            )
1135*da0073e9SAndroid Build Coastguard Worker
1136*da0073e9SAndroid Build Coastguard Worker        # Incorrect case: wrong dtype
1137*da0073e9SAndroid Build Coastguard Worker        out, tmp_edge = fn(x)
1138*da0073e9SAndroid Build Coastguard Worker        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
1139*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"):
1140*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
1141*da0073e9SAndroid Build Coastguard Worker                tmp_edge,
1142*da0073e9SAndroid Build Coastguard Worker                (x,),
1143*da0073e9SAndroid Build Coastguard Worker                grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
1144*da0073e9SAndroid Build Coastguard Worker            )
1145*da0073e9SAndroid Build Coastguard Worker
1146*da0073e9SAndroid Build Coastguard Worker    def test_grad_nonleaf(self):
1147*da0073e9SAndroid Build Coastguard Worker        x_init = torch.randn(2, 2, requires_grad=True)
1148*da0073e9SAndroid Build Coastguard Worker        x = x_init
1149*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, requires_grad=True)
1150*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.ones(2, 2)
1151*da0073e9SAndroid Build Coastguard Worker
1152*da0073e9SAndroid Build Coastguard Worker        def fn(x):
1153*da0073e9SAndroid Build Coastguard Worker            return x**2 + y * x + y**2
1154*da0073e9SAndroid Build Coastguard Worker
1155*da0073e9SAndroid Build Coastguard Worker        for _ in range(5):
1156*da0073e9SAndroid Build Coastguard Worker            (grad_x,) = torch.autograd.grad(
1157*da0073e9SAndroid Build Coastguard Worker                fn(x), x, grad_outputs=grad_output, create_graph=True
1158*da0073e9SAndroid Build Coastguard Worker            )
1159*da0073e9SAndroid Build Coastguard Worker
1160*da0073e9SAndroid Build Coastguard Worker            grad_x_expected = 2 * x + y
1161*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(y.grad)
1162*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(x.grad)
1163*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad_x, grad_x_expected)
1164*da0073e9SAndroid Build Coastguard Worker
1165*da0073e9SAndroid Build Coastguard Worker            x = x + 0.05 * grad_x
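            # Gradient-ascent step: moving along grad_x should increase fn(x),
            # which the assertGreater below verifies.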
1166*da0073e9SAndroid Build Coastguard Worker
1167*da0073e9SAndroid Build Coastguard Worker        val_init = fn(x_init).sum()
1168*da0073e9SAndroid Build Coastguard Worker        val_final = fn(x).sum()
1169*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(val_final, val_init)
1170*da0073e9SAndroid Build Coastguard Worker
1171*da0073e9SAndroid Build Coastguard Worker        x.backward(grad_output)
1172*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(y.grad)
1173*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(x_init.grad)
1174*da0073e9SAndroid Build Coastguard Worker
1175*da0073e9SAndroid Build Coastguard Worker    def test_grad_nonleaf_many_outputs(self):
1176*da0073e9SAndroid Build Coastguard Worker        # This checks an edge case for function callbacks
1177*da0073e9SAndroid Build Coastguard Worker        # We want to capture two grads of a function, but can only
1178*da0073e9SAndroid Build Coastguard Worker        # register a single callback.
1179*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 2, requires_grad=True)
1180*da0073e9SAndroid Build Coastguard Worker        a, b = x.chunk(2)
1181*da0073e9SAndroid Build Coastguard Worker
1182*da0073e9SAndroid Build Coastguard Worker        def hook(*grads):
1183*da0073e9SAndroid Build Coastguard Worker            hook_called[0] = True
1184*da0073e9SAndroid Build Coastguard Worker
1185*da0073e9SAndroid Build Coastguard Worker        hook_called = [False]
1186*da0073e9SAndroid Build Coastguard Worker        x.register_hook(hook)
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker        go = torch.randn(2, 2)
1189*da0073e9SAndroid Build Coastguard Worker        grad_a, grad_b = torch.autograd.grad(
1190*da0073e9SAndroid Build Coastguard Worker            (a + 2 * b), [a, b], grad_outputs=go, create_graph=True
1191*da0073e9SAndroid Build Coastguard Worker        )
1192*da0073e9SAndroid Build Coastguard Worker
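        # Gradients were only requested up to a and b, so backprop never reaches
        # x: its hook is not called and x.grad stays None.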
1193*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_a, go)
1194*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_b, go * 2)
1195*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(hook_called[0])
1196*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(x.grad)
1197*da0073e9SAndroid Build Coastguard Worker
1198*da0073e9SAndroid Build Coastguard Worker    def test_grad_nonleaf_register_hook(self):
1199*da0073e9SAndroid Build Coastguard Worker        # This checks an edge case for register_hook.
1200*da0073e9SAndroid Build Coastguard Worker        # We want to capture the grad of a non-leaf tensor,
1201*da0073e9SAndroid Build Coastguard Worker        # but avoid a segfault during backward of the other non-leaf tensors.
1202*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, requires_grad=True)
1203*da0073e9SAndroid Build Coastguard Worker        x_list = x.unbind()
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Worker        x0 = x_list[0]
1206*da0073e9SAndroid Build Coastguard Worker        hook_results = [None]
1207*da0073e9SAndroid Build Coastguard Worker
1208*da0073e9SAndroid Build Coastguard Worker        def hook(grad):
1209*da0073e9SAndroid Build Coastguard Worker            hook_results[0] = grad
1210*da0073e9SAndroid Build Coastguard Worker
1211*da0073e9SAndroid Build Coastguard Worker        x0.register_hook(hook)
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Worker        x_list[0].backward()
1214*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hook_results[0], torch.tensor(1.0))
1215*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.tensor([1.0, 0, 0, 0, 0])
1216*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, expected_grad)
1217*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(x_list[0].grad)
1218*da0073e9SAndroid Build Coastguard Worker
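        # Backward through the remaining unbind outputs: the shared grad_fn
        # still invokes x0's hook, but with None (no gradient flows to output 0
        # here); it must not crash, which previously manifested as a segfault.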
1219*da0073e9SAndroid Build Coastguard Worker        for i in range(1, 5):
1220*da0073e9SAndroid Build Coastguard Worker            x_list[i].backward()
1221*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(hook_results[0], None)
1222*da0073e9SAndroid Build Coastguard Worker            expected_grad[i] = 1.0
1223*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, expected_grad)
1224*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(x_list[i].grad)
1225*da0073e9SAndroid Build Coastguard Worker
1226*da0073e9SAndroid Build Coastguard Worker    def test_grad_materialize_grads(self):
1227*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(0.5, requires_grad=True)
1228*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1229*da0073e9SAndroid Build Coastguard Worker        y = x * a
1230*da0073e9SAndroid Build Coastguard Worker        dydx = torch.autograd.grad(y, x, create_graph=True)
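        # y = x * a is linear in x, so dydx (== a) does not depend on x and the
        # higher-order grads below are "unused": None unless materialized as zeros.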
1231*da0073e9SAndroid Build Coastguard Worker        d2ydx2_none = torch.autograd.grad(dydx, x, create_graph=True, allow_unused=True)
1232*da0073e9SAndroid Build Coastguard Worker        d2ydx2 = torch.autograd.grad(
1233*da0073e9SAndroid Build Coastguard Worker            dydx, x, create_graph=True, allow_unused=True, materialize_grads=True
1234*da0073e9SAndroid Build Coastguard Worker        )
1235*da0073e9SAndroid Build Coastguard Worker        # `allow_unused` is implicitly set to True by materialize_grads=True
1236*da0073e9SAndroid Build Coastguard Worker        d3ydx3 = torch.autograd.grad(d2ydx2, x, materialize_grads=True)
1237*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(d2ydx2_none[0])
1238*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d2ydx2[0].item(), 0)
1239*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d3ydx3[0].item(), 0)
1240*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1241*da0073e9SAndroid Build Coastguard Worker            ValueError, "Expected allow_unused to be True or not passed when"
1242*da0073e9SAndroid Build Coastguard Worker        ):
1243*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)
1244*da0073e9SAndroid Build Coastguard Worker
1245*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_on_non_leaf(self):
1246*da0073e9SAndroid Build Coastguard Worker        def hook(tensor):
1247*da0073e9SAndroid Build Coastguard Worker            tensor.sub_(1.0)
1248*da0073e9SAndroid Build Coastguard Worker
1249*da0073e9SAndroid Build Coastguard Worker        leaf = torch.rand(3, requires_grad=True)
1250*da0073e9SAndroid Build Coastguard Worker        non_leaf = 2.0 * leaf
1251*da0073e9SAndroid Build Coastguard Worker
1252*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1253*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1254*da0073e9SAndroid Build Coastguard Worker            "post accumulate grad hooks cannot be registered on non-leaf tensors",
1255*da0073e9SAndroid Build Coastguard Worker        ):
1256*da0073e9SAndroid Build Coastguard Worker            non_leaf.register_post_accumulate_grad_hook(hook)
1257*da0073e9SAndroid Build Coastguard Worker
1258*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_multiple_hooks(self):
1259*da0073e9SAndroid Build Coastguard Worker        def hook1(tensor):
1260*da0073e9SAndroid Build Coastguard Worker            tensor.sub_(tensor.grad)
1261*da0073e9SAndroid Build Coastguard Worker
1262*da0073e9SAndroid Build Coastguard Worker        def hook2(tensor):
1263*da0073e9SAndroid Build Coastguard Worker            tensor.mul_(4.0)
1264*da0073e9SAndroid Build Coastguard Worker
1265*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(3, requires_grad=True)
1266*da0073e9SAndroid Build Coastguard Worker        tensor_ref = tensor.clone().detach()
1267*da0073e9SAndroid Build Coastguard Worker        tensor.register_post_accumulate_grad_hook(hook1)
1268*da0073e9SAndroid Build Coastguard Worker        tensor.register_post_accumulate_grad_hook(hook2)
1269*da0073e9SAndroid Build Coastguard Worker        sum = tensor.sum()
1270*da0073e9SAndroid Build Coastguard Worker        sum.backward()
1271*da0073e9SAndroid Build Coastguard Worker        # both hooks should be called, in order
1272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(4.0 * (tensor_ref - 1.0), tensor)
1273*da0073e9SAndroid Build Coastguard Worker
1274*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_multiple_tensors(self):
1275*da0073e9SAndroid Build Coastguard Worker        def hook(tensor):
1276*da0073e9SAndroid Build Coastguard Worker            tensor.sub_(tensor.grad)
1277*da0073e9SAndroid Build Coastguard Worker
1278*da0073e9SAndroid Build Coastguard Worker        tensor1 = torch.rand(3, requires_grad=True)
1279*da0073e9SAndroid Build Coastguard Worker        tensor1_ref = tensor1.clone().detach()
1280*da0073e9SAndroid Build Coastguard Worker        tensor2 = torch.rand(5, requires_grad=True)
1281*da0073e9SAndroid Build Coastguard Worker        tensor2_ref = tensor2.clone().detach()
1282*da0073e9SAndroid Build Coastguard Worker        tensor1.register_post_accumulate_grad_hook(hook)
1283*da0073e9SAndroid Build Coastguard Worker        tensor2.register_post_accumulate_grad_hook(hook)
1284*da0073e9SAndroid Build Coastguard Worker        tensor1.sum().backward()
1285*da0073e9SAndroid Build Coastguard Worker        tensor2.sum().backward()
1286*da0073e9SAndroid Build Coastguard Worker        # both tensors should have been modified
1287*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor1_ref - 1.0, tensor1)
1288*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor2_ref - 1.0, tensor2)
1289*da0073e9SAndroid Build Coastguard Worker
1290*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_returns_not_None(self):
1291*da0073e9SAndroid Build Coastguard Worker        def bad_hook(tensor):
1292*da0073e9SAndroid Build Coastguard Worker            return tensor.grad
1293*da0073e9SAndroid Build Coastguard Worker
1294*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(2, 3, requires_grad=True)
1295*da0073e9SAndroid Build Coastguard Worker        tensor.register_post_accumulate_grad_hook(bad_hook)
1296*da0073e9SAndroid Build Coastguard Worker        # should error!
1297*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
1298*da0073e9SAndroid Build Coastguard Worker            tensor.sum().backward()
1299*da0073e9SAndroid Build Coastguard Worker
1300*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_e2e(self):
1301*da0073e9SAndroid Build Coastguard Worker        def setup_optim_in_bwd(model):
1302*da0073e9SAndroid Build Coastguard Worker            optims = {}
1303*da0073e9SAndroid Build Coastguard Worker            handles = []
1304*da0073e9SAndroid Build Coastguard Worker
1305*da0073e9SAndroid Build Coastguard Worker            def optim_step_hook(param):
1306*da0073e9SAndroid Build Coastguard Worker                optims[param].step()
1307*da0073e9SAndroid Build Coastguard Worker                optims[param].zero_grad()
1308*da0073e9SAndroid Build Coastguard Worker
1309*da0073e9SAndroid Build Coastguard Worker            for p in model.parameters():
1310*da0073e9SAndroid Build Coastguard Worker                optims[p] = torch.optim.Adam([p])
1311*da0073e9SAndroid Build Coastguard Worker                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))
1312*da0073e9SAndroid Build Coastguard Worker
1313*da0073e9SAndroid Build Coastguard Worker            return handles
1314*da0073e9SAndroid Build Coastguard Worker
1315*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.Linear(3, 2)
1316*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(2, 3)
1317*da0073e9SAndroid Build Coastguard Worker        handles = setup_optim_in_bwd(model)
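        # `model` is now updated inside backward() via the post-accumulate-grad
        # hooks, so the training loop below has no explicit step()/zero_grad()
        # calls for it, unlike the reference copy.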
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker        # make a copy for reference
1320*da0073e9SAndroid Build Coastguard Worker        model_copy = deepcopy(model)
1321*da0073e9SAndroid Build Coastguard Worker        optim_copy = torch.optim.Adam(model_copy.parameters())
1322*da0073e9SAndroid Build Coastguard Worker
1323*da0073e9SAndroid Build Coastguard Worker        iters = 5
1324*da0073e9SAndroid Build Coastguard Worker
1325*da0073e9SAndroid Build Coastguard Worker        for _ in range(iters):
1326*da0073e9SAndroid Build Coastguard Worker            loss = model(input).sum()
1327*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1328*da0073e9SAndroid Build Coastguard Worker
1329*da0073e9SAndroid Build Coastguard Worker            loss_copy = model_copy(input).sum()
1330*da0073e9SAndroid Build Coastguard Worker            loss_copy.backward()
1331*da0073e9SAndroid Build Coastguard Worker            optim_copy.step()
1332*da0073e9SAndroid Build Coastguard Worker            optim_copy.zero_grad()
1333*da0073e9SAndroid Build Coastguard Worker
1334*da0073e9SAndroid Build Coastguard Worker        params_copy = []  # freeze a copy of the params to compare later
1335*da0073e9SAndroid Build Coastguard Worker        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
1336*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p_reference, p)
1337*da0073e9SAndroid Build Coastguard Worker            params_copy.append(p_reference.clone().detach())
1338*da0073e9SAndroid Build Coastguard Worker
1339*da0073e9SAndroid Build Coastguard Worker        # After removing the handles, the model should no longer update.
1340*da0073e9SAndroid Build Coastguard Worker        for h in handles:
1341*da0073e9SAndroid Build Coastguard Worker            h.remove()
1342*da0073e9SAndroid Build Coastguard Worker
1343*da0073e9SAndroid Build Coastguard Worker        for _ in range(iters):
1344*da0073e9SAndroid Build Coastguard Worker            loss = model(input).sum()
1345*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1346*da0073e9SAndroid Build Coastguard Worker
1347*da0073e9SAndroid Build Coastguard Worker            loss_copy = model_copy(input).sum()
1348*da0073e9SAndroid Build Coastguard Worker            loss_copy.backward()
1349*da0073e9SAndroid Build Coastguard Worker            optim_copy.step()
1350*da0073e9SAndroid Build Coastguard Worker            optim_copy.zero_grad()
1351*da0073e9SAndroid Build Coastguard Worker
1352*da0073e9SAndroid Build Coastguard Worker        for p_static, p_reference, p in zip(
1353*da0073e9SAndroid Build Coastguard Worker            params_copy, model_copy.parameters(), model.parameters()
1354*da0073e9SAndroid Build Coastguard Worker        ):
1355*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p_static, p)
1356*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(p_reference, p)
1357*da0073e9SAndroid Build Coastguard Worker
1358*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_gets_cleaned_up(self):
1359*da0073e9SAndroid Build Coastguard Worker        def fun_stuff_with_hook():
1360*da0073e9SAndroid Build Coastguard Worker            thing_to_put_in_hook = torch.rand(3)
1361*da0073e9SAndroid Build Coastguard Worker
1362*da0073e9SAndroid Build Coastguard Worker            def hook(tensor):
1363*da0073e9SAndroid Build Coastguard Worker                tensor.sub_(tensor.grad)
1364*da0073e9SAndroid Build Coastguard Worker                tensor.add_(thing_to_put_in_hook)
1365*da0073e9SAndroid Build Coastguard Worker
1366*da0073e9SAndroid Build Coastguard Worker            tensor = torch.rand(3, requires_grad=True)
1367*da0073e9SAndroid Build Coastguard Worker            tensor.register_post_accumulate_grad_hook(hook)
1368*da0073e9SAndroid Build Coastguard Worker            tensor.sum().backward()
1369*da0073e9SAndroid Build Coastguard Worker            ref = weakref.ref(thing_to_put_in_hook)
1370*da0073e9SAndroid Build Coastguard Worker            gc.collect()
1371*da0073e9SAndroid Build Coastguard Worker            return tensor, ref
1372*da0073e9SAndroid Build Coastguard Worker
1373*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
1374*da0073e9SAndroid Build Coastguard Worker            tensor, ref = fun_stuff_with_hook()
1375*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(
1376*da0073e9SAndroid Build Coastguard Worker                ref()
1377*da0073e9SAndroid Build Coastguard Worker            )  # thing_to_put_in_hook should be kept alive by tensor
1378*da0073e9SAndroid Build Coastguard Worker
1379*da0073e9SAndroid Build Coastguard Worker            del tensor
1380*da0073e9SAndroid Build Coastguard Worker            gc.collect()
1381*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(ref())  # thing_to_put_in_hook should be cleaned
1382*da0073e9SAndroid Build Coastguard Worker
1383*da0073e9SAndroid Build Coastguard Worker    def test_post_accumulate_grad_hook_ordering(self):
1384*da0073e9SAndroid Build Coastguard Worker        tensor = torch.rand(3, requires_grad=True)
1385*da0073e9SAndroid Build Coastguard Worker
1386*da0073e9SAndroid Build Coastguard Worker        def pre_hook(grad):
1387*da0073e9SAndroid Build Coastguard Worker            return grad.sub(2.0)
1388*da0073e9SAndroid Build Coastguard Worker
1389*da0073e9SAndroid Build Coastguard Worker        def acc_grad_node_pre_hook(grad_out):
1390*da0073e9SAndroid Build Coastguard Worker            return (grad_out[0].div(5.0),)
1391*da0073e9SAndroid Build Coastguard Worker
1392*da0073e9SAndroid Build Coastguard Worker        def post_acc_grad_hook(tensor):
1393*da0073e9SAndroid Build Coastguard Worker            tensor.grad.add_(0.5)
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker        def acc_grad_node_post_hook(grad_in, grad_out):
1396*da0073e9SAndroid Build Coastguard Worker            tensor.grad = grad_out[0].mul(10)
1397*da0073e9SAndroid Build Coastguard Worker
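        # A leaf tensor has no .grad_fn, so reach its AccumulateGrad node
        # through a temporary view's grad_fn.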
1398*da0073e9SAndroid Build Coastguard Worker        acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0]
1399*da0073e9SAndroid Build Coastguard Worker        tensor.register_hook(pre_hook)
1400*da0073e9SAndroid Build Coastguard Worker        acc_grad.register_prehook(acc_grad_node_pre_hook)
1401*da0073e9SAndroid Build Coastguard Worker        tensor.register_post_accumulate_grad_hook(post_acc_grad_hook)
1402*da0073e9SAndroid Build Coastguard Worker        acc_grad.register_hook(acc_grad_node_post_hook)
1403*da0073e9SAndroid Build Coastguard Worker        tensor.sum().backward()
1404*da0073e9SAndroid Build Coastguard Worker
1405*da0073e9SAndroid Build Coastguard Worker        # the hooks should run in the order of:
1406*da0073e9SAndroid Build Coastguard Worker        #   1. tensor prehook
1407*da0073e9SAndroid Build Coastguard Worker        #   2. acc_grad prehook
1408*da0073e9SAndroid Build Coastguard Worker        #   3. tensor post acc_grad hook
1409*da0073e9SAndroid Build Coastguard Worker        #   4. acc_grad posthook
1410*da0073e9SAndroid Build Coastguard Worker        # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3
1411*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch.tensor([3.0, 3.0, 3.0]), tensor.grad)
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker    def test_hook_with_no_name(self):
1414*da0073e9SAndroid Build Coastguard Worker        # Create a hook that does not have a __name__ attribute
1415*da0073e9SAndroid Build Coastguard Worker        class MyHookClass:
1416*da0073e9SAndroid Build Coastguard Worker            def __call__(self, grad):
1417*da0073e9SAndroid Build Coastguard Worker                return grad.clone()
1418*da0073e9SAndroid Build Coastguard Worker
1419*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, requires_grad=True).clone()
1420*da0073e9SAndroid Build Coastguard Worker        x.register_hook(MyHookClass())
1421*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
1422*da0073e9SAndroid Build Coastguard Worker        # Should run fine
1423*da0073e9SAndroid Build Coastguard Worker
1424*da0073e9SAndroid Build Coastguard Worker    def test_prehook_ordering(self):
1425*da0073e9SAndroid Build Coastguard Worker        # Hooks registered to the tensor are ordered before those
1426*da0073e9SAndroid Build Coastguard Worker        # that are registered to its grad_fn.
1427*da0073e9SAndroid Build Coastguard Worker        log = []
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker        def hook1(g):
1430*da0073e9SAndroid Build Coastguard Worker            log.append(1)
1431*da0073e9SAndroid Build Coastguard Worker            return g * 3
1432*da0073e9SAndroid Build Coastguard Worker
1433*da0073e9SAndroid Build Coastguard Worker        def hook2(gs):
1434*da0073e9SAndroid Build Coastguard Worker            log.append(2)
1435*da0073e9SAndroid Build Coastguard Worker            return tuple(g * 2 for g in gs)
1436*da0073e9SAndroid Build Coastguard Worker
1437*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1438*da0073e9SAndroid Build Coastguard Worker        b = a.clone()
1439*da0073e9SAndroid Build Coastguard Worker
1440*da0073e9SAndroid Build Coastguard Worker        b.grad_fn.register_prehook(hook2)
1441*da0073e9SAndroid Build Coastguard Worker        b.register_hook(hook1)
1442*da0073e9SAndroid Build Coastguard Worker        b.grad_fn.register_prehook(hook2)
1443*da0073e9SAndroid Build Coastguard Worker
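        # a's AccumulateGrad node, reached through b's grad_fn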
1444*da0073e9SAndroid Build Coastguard Worker        acc = b.grad_fn.next_functions[0][0]
1445*da0073e9SAndroid Build Coastguard Worker        a.register_hook(hook1)
1446*da0073e9SAndroid Build Coastguard Worker        acc.register_prehook(hook2)
1447*da0073e9SAndroid Build Coastguard Worker        a.register_hook(hook1)
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        b.sum().backward(retain_graph=True)
1450*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log, [1, 2, 2, 1, 1, 2])
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker        # grad also runs hooks on accumulate grad nodes, even though
1453*da0073e9SAndroid Build Coastguard Worker        # the accumulate grad nodes are not actually executed
1454*da0073e9SAndroid Build Coastguard Worker        log = []
1455*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(b.sum(), inputs=(a,), retain_graph=True)
1456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log, [1, 2, 2, 1, 1])
1457*da0073e9SAndroid Build Coastguard Worker
1458*da0073e9SAndroid Build Coastguard Worker        log = []
1459*da0073e9SAndroid Build Coastguard Worker        b.sum().backward(inputs=(b,))
1460*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(log, [1, 2, 2])
1461*da0073e9SAndroid Build Coastguard Worker        # retains_grad hooks do not observe the modifications made by the grad_fn
1462*da0073e9SAndroid Build Coastguard Worker        # pre-hooks, because those run after the grad is captured (hence 1 * 3)
1463*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad.item(), 3)
1464*da0073e9SAndroid Build Coastguard Worker
1465*da0073e9SAndroid Build Coastguard Worker    def test_retains_grad_can_always_observe_tensor_prehook(self):
1466*da0073e9SAndroid Build Coastguard Worker        def tensor_prehook(g):
1467*da0073e9SAndroid Build Coastguard Worker            return g * 2
1468*da0073e9SAndroid Build Coastguard Worker
1469*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1470*da0073e9SAndroid Build Coastguard Worker        b = a.clone()
1471*da0073e9SAndroid Build Coastguard Worker        b.register_hook(tensor_prehook)
1472*da0073e9SAndroid Build Coastguard Worker        b.retain_grad()
1473*da0073e9SAndroid Build Coastguard Worker        b.register_hook(tensor_prehook)
1474*da0073e9SAndroid Build Coastguard Worker
1475*da0073e9SAndroid Build Coastguard Worker        b.clone().backward()
1476*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad.item(), 4)
1477*da0073e9SAndroid Build Coastguard Worker
1478*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1479*da0073e9SAndroid Build Coastguard Worker        b = a.clone()
1480*da0073e9SAndroid Build Coastguard Worker        b.retain_grad()
1481*da0073e9SAndroid Build Coastguard Worker        b.register_hook(tensor_prehook)
1482*da0073e9SAndroid Build Coastguard Worker
1483*da0073e9SAndroid Build Coastguard Worker        b.clone().backward()
1484*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad.item(), 2)
1485*da0073e9SAndroid Build Coastguard Worker
1486*da0073e9SAndroid Build Coastguard Worker    def test_accumulate_grad_posthooks_can_observe_tensor_prehook(self):
1487*da0073e9SAndroid Build Coastguard Worker        # Post hooks on the accumulate grad node should be able to observe
1488*da0073e9SAndroid Build Coastguard Worker        # changes to the grad made by tensor prehooks
1489*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1490*da0073e9SAndroid Build Coastguard Worker
1491*da0073e9SAndroid Build Coastguard Worker        def tensor_prehook(g):
1492*da0073e9SAndroid Build Coastguard Worker            return g * 2
1493*da0073e9SAndroid Build Coastguard Worker
1494*da0073e9SAndroid Build Coastguard Worker        def posthook(gO, gI):
1495*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(gI[0], a * 2))
1496*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(gO), 0)
1497*da0073e9SAndroid Build Coastguard Worker
1498*da0073e9SAndroid Build Coastguard Worker        def prehook(gI):
1499*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(gI[0], a * 2))
1500*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(gI), 1)
1501*da0073e9SAndroid Build Coastguard Worker
1502*da0073e9SAndroid Build Coastguard Worker        b = a.clone()
1503*da0073e9SAndroid Build Coastguard Worker        acc = b.grad_fn.next_functions[0][0]
1504*da0073e9SAndroid Build Coastguard Worker        acc.register_hook(posthook)
1505*da0073e9SAndroid Build Coastguard Worker        acc.register_prehook(prehook)
1506*da0073e9SAndroid Build Coastguard Worker        a.register_hook(tensor_prehook)
1507*da0073e9SAndroid Build Coastguard Worker
1508*da0073e9SAndroid Build Coastguard Worker        b.backward()
1509*da0073e9SAndroid Build Coastguard Worker
1510*da0073e9SAndroid Build Coastguard Worker    def test_accumulate_grad_posthooks_should_not_execute(self):
1511*da0073e9SAndroid Build Coastguard Worker        def tensor_prehook(g):
1512*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError
1513*da0073e9SAndroid Build Coastguard Worker
1514*da0073e9SAndroid Build Coastguard Worker        def posthook(gO, gI):
1515*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError
1516*da0073e9SAndroid Build Coastguard Worker
1517*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1518*da0073e9SAndroid Build Coastguard Worker        a.register_hook(tensor_prehook)
1519*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(1.0, requires_grad=True)
1520*da0073e9SAndroid Build Coastguard Worker        c = a.clone()
1521*da0073e9SAndroid Build Coastguard Worker        acc = c.grad_fn.next_functions[0][0]
1522*da0073e9SAndroid Build Coastguard Worker        acc.register_hook(posthook)
1523*da0073e9SAndroid Build Coastguard Worker
1524*da0073e9SAndroid Build Coastguard Worker        out = a + b + c
1525*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(inputs=[b])
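        # Backward is restricted to input b, so a's tensor prehook and the
        # posthook on c's AccumulateGrad node must not run; either would raise.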
1526*da0073e9SAndroid Build Coastguard Worker
1527*da0073e9SAndroid Build Coastguard Worker    def test_hook_edge_case_when_called_with_grad(self):
1528*da0073e9SAndroid Build Coastguard Worker        # torch.autograd.grad executes the tensor hooks of the next node, but
1529*da0073e9SAndroid Build Coastguard Worker        # not its grad_fn pre-hooks or post-hooks
1530*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
1531*da0073e9SAndroid Build Coastguard Worker        b = a * 2
1532*da0073e9SAndroid Build Coastguard Worker        c = b * 2
1533*da0073e9SAndroid Build Coastguard Worker
1534*da0073e9SAndroid Build Coastguard Worker        tensor_hook_count = [0]
1535*da0073e9SAndroid Build Coastguard Worker        prehook_count = [0]
1536*da0073e9SAndroid Build Coastguard Worker        posthook_count = [0]
1537*da0073e9SAndroid Build Coastguard Worker
1538*da0073e9SAndroid Build Coastguard Worker        def reset_counts():
1539*da0073e9SAndroid Build Coastguard Worker            nonlocal tensor_hook_count, prehook_count, posthook_count
1540*da0073e9SAndroid Build Coastguard Worker            tensor_hook_count = [0]
1541*da0073e9SAndroid Build Coastguard Worker            prehook_count = [0]
1542*da0073e9SAndroid Build Coastguard Worker            posthook_count = [0]
1543*da0073e9SAndroid Build Coastguard Worker
1544*da0073e9SAndroid Build Coastguard Worker        def tensor_prehook(g):
1545*da0073e9SAndroid Build Coastguard Worker            tensor_hook_count[0] += 1
1546*da0073e9SAndroid Build Coastguard Worker
1547*da0073e9SAndroid Build Coastguard Worker        def prehook(g):
1548*da0073e9SAndroid Build Coastguard Worker            prehook_count[0] += 1
1549*da0073e9SAndroid Build Coastguard Worker
1550*da0073e9SAndroid Build Coastguard Worker        def posthook(gI, gO):
1551*da0073e9SAndroid Build Coastguard Worker            posthook_count[0] += 1
1552*da0073e9SAndroid Build Coastguard Worker
1553*da0073e9SAndroid Build Coastguard Worker        a.register_hook(tensor_prehook)
1554*da0073e9SAndroid Build Coastguard Worker        b.register_hook(tensor_prehook)
1555*da0073e9SAndroid Build Coastguard Worker        acc = b.grad_fn.next_functions[0][0]
1556*da0073e9SAndroid Build Coastguard Worker        acc.register_hook(posthook)
1557*da0073e9SAndroid Build Coastguard Worker        acc.register_prehook(prehook)
1558*da0073e9SAndroid Build Coastguard Worker        b.grad_fn.register_hook(posthook)
1559*da0073e9SAndroid Build Coastguard Worker        b.grad_fn.register_prehook(prehook)
1560*da0073e9SAndroid Build Coastguard Worker
1561*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(c, inputs=(b,), retain_graph=True)
1562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor_hook_count[0], 1)
1563*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(posthook_count[0], 0)
1564*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(prehook_count[0], 0)
1565*da0073e9SAndroid Build Coastguard Worker        reset_counts()
1566*da0073e9SAndroid Build Coastguard Worker
1567*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(c, inputs=(a, b), retain_graph=True)
1568*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor_hook_count[0], 2)
1569*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(posthook_count[0], 1)
1570*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(prehook_count[0], 1)
1571*da0073e9SAndroid Build Coastguard Worker        reset_counts()
1572*da0073e9SAndroid Build Coastguard Worker
1573*da0073e9SAndroid Build Coastguard Worker        c.backward(retain_graph=True)
1574*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor_hook_count[0], 2)
1575*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(posthook_count[0], 2)
1576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(prehook_count[0], 2)
1577*da0073e9SAndroid Build Coastguard Worker        reset_counts()
1578*da0073e9SAndroid Build Coastguard Worker
1579*da0073e9SAndroid Build Coastguard Worker        c.backward(inputs=(a, b), retain_graph=True)
1580*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor_hook_count[0], 2)
1581*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(posthook_count[0], 2)
1582*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(prehook_count[0], 2)
1583*da0073e9SAndroid Build Coastguard Worker
1584*da0073e9SAndroid Build Coastguard Worker    def test_sharded_grad(self):
1585*da0073e9SAndroid Build Coastguard Worker        leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)]
1586*da0073e9SAndroid Build Coastguard Worker        intermediates = [l * i + l * l for i, l in enumerate(leaves)]
1587*da0073e9SAndroid Build Coastguard Worker        loss = sum(v * i for i, v in enumerate(intermediates)).sum()
1588*da0073e9SAndroid Build Coastguard Worker
1589*da0073e9SAndroid Build Coastguard Worker        # define a helper for dividing intermediates into groups
1590*da0073e9SAndroid Build Coastguard Worker        def group(l, group_size):
1591*da0073e9SAndroid Build Coastguard Worker            return (l[i : i + group_size] for i in range(0, len(l), group_size))
1592*da0073e9SAndroid Build Coastguard Worker
1593*da0073e9SAndroid Build Coastguard Worker        # Compute the d loss / d intermediates in chunks of shard_size
1594*da0073e9SAndroid Build Coastguard Worker        shard_size = 2
1595*da0073e9SAndroid Build Coastguard Worker        d_intermediates = [
1596*da0073e9SAndroid Build Coastguard Worker            d_i
1597*da0073e9SAndroid Build Coastguard Worker            for intermediates_batch in group(intermediates, shard_size)
1598*da0073e9SAndroid Build Coastguard Worker            for d_i in torch.autograd.grad(loss, intermediates_batch)
1599*da0073e9SAndroid Build Coastguard Worker        ]
1600*da0073e9SAndroid Build Coastguard Worker        # Compute rest of backward pass
1601*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(intermediates, d_intermediates)
1602*da0073e9SAndroid Build Coastguard Worker
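        # d loss / d leaf_i = i * (i + 2 * l); the leaves are zeros, so this is
        # i * i, which matches the i * i * (1 + l) check below.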
1603*da0073e9SAndroid Build Coastguard Worker        for i, l in enumerate(leaves):
1604*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(l.grad, i * i * (1 + l))
1605*da0073e9SAndroid Build Coastguard Worker
1606*da0073e9SAndroid Build Coastguard Worker    def test_backward_badcalls(self):
1607*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
1608*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1609*da0073e9SAndroid Build Coastguard Worker            x.backward()
1610*da0073e9SAndroid Build Coastguard Worker
1611*da0073e9SAndroid Build Coastguard Worker    def test_grad_badcalls(self):
1612*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1)
1613*da0073e9SAndroid Build Coastguard Worker        y = x**2
1614*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1615*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(x, y)
1616*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
1617*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(y, x)
1618*da0073e9SAndroid Build Coastguard Worker
1619*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1, requires_grad=True)
1620*da0073e9SAndroid Build Coastguard Worker        y = x**2
1621*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(y, x)  # this should succeed now
1622*da0073e9SAndroid Build Coastguard Worker
1623*da0073e9SAndroid Build Coastguard Worker    def test_grad_empty_inputs(self):
1624*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0], requires_grad=True)
1625*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "grad requires non-empty inputs."):
1626*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(2 * x, [], grad_outputs=torch.tensor([1.0]))
1627*da0073e9SAndroid Build Coastguard Worker
1628*da0073e9SAndroid Build Coastguard Worker    def test_grad_fn_badcalls(self):
1629*da0073e9SAndroid Build Coastguard Worker        error_regex = "expected .* arguments, got .* instead"
1630*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1, requires_grad=True)
1631*da0073e9SAndroid Build Coastguard Worker        y = x**2
1632*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, error_regex):
1633*da0073e9SAndroid Build Coastguard Worker            y.grad_fn(x.detach(), x.detach())  # too many
1634*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, error_regex):
1635*da0073e9SAndroid Build Coastguard Worker            y.grad_fn()  # too few
1636*da0073e9SAndroid Build Coastguard Worker
1637*da0073e9SAndroid Build Coastguard Worker        y.grad_fn(x.detach())  # this should succeed
1638*da0073e9SAndroid Build Coastguard Worker
1639*da0073e9SAndroid Build Coastguard Worker    def test_grad_unreachable(self):
1640*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1, requires_grad=True)
1641*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(1, requires_grad=True)
1642*da0073e9SAndroid Build Coastguard Worker        # Make sure x and y have grad accumulators allocated
1643*da0073e9SAndroid Build Coastguard Worker        z = x * 2
1644*da0073e9SAndroid Build Coastguard Worker        w = y * 2
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker        grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True)
1647*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_x, x * 2)
1648*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(grad_y)
1649*da0073e9SAndroid Build Coastguard Worker
1650*da0073e9SAndroid Build Coastguard Worker        # This is slightly different than the case above, because z doesn't even
1651*da0073e9SAndroid Build Coastguard Worker        # have a grad accumulator allocated.
1652*da0073e9SAndroid Build Coastguard Worker        z = torch.ones(1, requires_grad=True)
1653*da0073e9SAndroid Build Coastguard Worker        grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True)
1654*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_x, x * 2)
1655*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(grad_z)
1656*da0073e9SAndroid Build Coastguard Worker
1657*da0073e9SAndroid Build Coastguard Worker        # allow_unused=False, but grads contains None inside, should throw
1658*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Set allow_unused=True"):
1659*da0073e9SAndroid Build Coastguard Worker            grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=False)
1660*da0073e9SAndroid Build Coastguard Worker
1661*da0073e9SAndroid Build Coastguard Worker    def test_grad_unreachable_discovery(self):
1662*da0073e9SAndroid Build Coastguard Worker        # Test that certain nodes are not erroneously executed when an input
1663*da0073e9SAndroid Build Coastguard Worker        # is unreachable. See #39784
1664*da0073e9SAndroid Build Coastguard Worker        class MyFunc(torch.autograd.Function):
1665*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1666*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
1667*da0073e9SAndroid Build Coastguard Worker                return x
1668*da0073e9SAndroid Build Coastguard Worker
1669*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1670*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
1671*da0073e9SAndroid Build Coastguard Worker                self.fail("This node should not be executed!")
1672*da0073e9SAndroid Build Coastguard Worker
1673*da0073e9SAndroid Build Coastguard Worker        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
1674*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(1, requires_grad=True)
1675*da0073e9SAndroid Build Coastguard Worker        (gY,) = torch.autograd.grad(x, (y,), allow_unused=True)
1676*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(gY)
1677*da0073e9SAndroid Build Coastguard Worker
1678*da0073e9SAndroid Build Coastguard Worker        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
1679*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(1, requires_grad=True)
1680*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(1, requires_grad=True)
1681*da0073e9SAndroid Build Coastguard Worker        (gY, gZ) = torch.autograd.grad(x + z, (y, z), allow_unused=True)
1682*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(gY)
1683*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(gZ)
1684*da0073e9SAndroid Build Coastguard Worker
1685*da0073e9SAndroid Build Coastguard Worker        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
1686*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(1, requires_grad=True)
1687*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(x, inputs=(y,))  # allow_unused is implicitly True!
1688*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(y.grad)
1689*da0073e9SAndroid Build Coastguard Worker
1690*da0073e9SAndroid Build Coastguard Worker    def test_grad_batched_grad(self):
1691*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
1692*da0073e9SAndroid Build Coastguard Worker
1693*da0073e9SAndroid Build Coastguard Worker        out = x.clone()  # Size([2, 2])
1694*da0073e9SAndroid Build Coastguard Worker        batched_grad = (
1695*da0073e9SAndroid Build Coastguard Worker            torch.arange(3).expand(2, 2, 3).transpose(0, 2)
1696*da0073e9SAndroid Build Coastguard Worker        )  # Size([3, 2, 2])
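        # With is_grads_batched=True, the leading dimension (3) is treated as a
        # batch dimension and the remaining dims must match out's shape (2, 2).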
1697*da0073e9SAndroid Build Coastguard Worker        (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
1698*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1699*da0073e9SAndroid Build Coastguard Worker            grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype)
1700*da0073e9SAndroid Build Coastguard Worker        )
1701*da0073e9SAndroid Build Coastguard Worker
1702*da0073e9SAndroid Build Coastguard Worker        # Detect shape mismatch
1703*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.ones(2, 2)
1704*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1705*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "If `is_grads_batched=True`, we interpret the first"
1706*da0073e9SAndroid Build Coastguard Worker        ):
1707*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
1708*da0073e9SAndroid Build Coastguard Worker                outputs=out,
1709*da0073e9SAndroid Build Coastguard Worker                grad_outputs=(grad_out,),
1710*da0073e9SAndroid Build Coastguard Worker                inputs=(x,),
1711*da0073e9SAndroid Build Coastguard Worker                is_grads_batched=True,
1712*da0073e9SAndroid Build Coastguard Worker            )
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker        # Scalar outputs
1715*da0073e9SAndroid Build Coastguard Worker        out = x.sum()  # Size([])
1716*da0073e9SAndroid Build Coastguard Worker        batched_grad = torch.arange(3)  # Size([3])
1717*da0073e9SAndroid Build Coastguard Worker        (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
1718*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
1719*da0073e9SAndroid Build Coastguard Worker            grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype)
1720*da0073e9SAndroid Build Coastguard Worker        )
1721*da0073e9SAndroid Build Coastguard Worker
1722*da0073e9SAndroid Build Coastguard Worker        # We consider a scalar and a size-1 tensor to be a mismatch. This is consistent with the current non-batched behavior.
1723*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.ones(2).unsqueeze(1)
1724*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1725*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "If `is_grads_batched=True`, we interpret the first"
1726*da0073e9SAndroid Build Coastguard Worker        ):
1727*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(
1728*da0073e9SAndroid Build Coastguard Worker                outputs=out,
1729*da0073e9SAndroid Build Coastguard Worker                grad_outputs=(grad_out,),
1730*da0073e9SAndroid Build Coastguard Worker                inputs=(x,),
1731*da0073e9SAndroid Build Coastguard Worker                is_grads_batched=True,
1732*da0073e9SAndroid Build Coastguard Worker            )
1733*da0073e9SAndroid Build Coastguard Worker
1734*da0073e9SAndroid Build Coastguard Worker    def test_hooks(self):
1735*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
1736*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(5, 5) * 4
1737*da0073e9SAndroid Build Coastguard Worker        y.requires_grad_(True)
1738*da0073e9SAndroid Build Coastguard Worker
1739*da0073e9SAndroid Build Coastguard Worker        counter = [0]
1740*da0073e9SAndroid Build Coastguard Worker
1741*da0073e9SAndroid Build Coastguard Worker        def bw_hook(inc, grad):
1742*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(grad, torch.Tensor)
1743*da0073e9SAndroid Build Coastguard Worker            counter[0] += inc
1744*da0073e9SAndroid Build Coastguard Worker
1745*da0073e9SAndroid Build Coastguard Worker        z = x**2 + x * 2 + x * y + y
1746*da0073e9SAndroid Build Coastguard Worker        x.register_hook(lambda *args: bw_hook(0, *args))
1747*da0073e9SAndroid Build Coastguard Worker        test = z.register_hook(lambda *args: bw_hook(1, *args))
1748*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(5, 5), retain_graph=True)
1749*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
1750*da0073e9SAndroid Build Coastguard Worker
1751*da0073e9SAndroid Build Coastguard Worker        test2 = z.register_hook(lambda *args: bw_hook(2, *args))
1752*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(5, 5), retain_graph=True)
1753*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 4)
1754*da0073e9SAndroid Build Coastguard Worker
1755*da0073e9SAndroid Build Coastguard Worker        test2.remove()
1756*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(5, 5), retain_graph=True)
1757*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 5)
1758*da0073e9SAndroid Build Coastguard Worker
1759*da0073e9SAndroid Build Coastguard Worker        def bw_hook_modify(grad):
1760*da0073e9SAndroid Build Coastguard Worker            return grad.mul(2)
1761*da0073e9SAndroid Build Coastguard Worker
1762*da0073e9SAndroid Build Coastguard Worker        test.remove()
1763*da0073e9SAndroid Build Coastguard Worker        z.register_hook(bw_hook_modify)
1764*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1765*da0073e9SAndroid Build Coastguard Worker            y.grad.zero_()
1766*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(5, 5), retain_graph=True)
1767*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, (x + 1) * 2)
1768*da0073e9SAndroid Build Coastguard Worker
1769*da0073e9SAndroid Build Coastguard Worker        y.register_hook(bw_hook_modify)
1770*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1771*da0073e9SAndroid Build Coastguard Worker            y.grad.zero_()
1772*da0073e9SAndroid Build Coastguard Worker        z.backward(torch.ones(5, 5))
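        # dz/dy = x + 1; z's hook doubles the incoming grad and y's own hook
        # doubles it again, giving (x + 1) * 4.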
1773*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, (x + 1) * 4)
1774*da0073e9SAndroid Build Coastguard Worker
1775*da0073e9SAndroid Build Coastguard Worker    def _get_mul2(self, use_custom_function):
1776*da0073e9SAndroid Build Coastguard Worker        if use_custom_function:
1777*da0073e9SAndroid Build Coastguard Worker
1778*da0073e9SAndroid Build Coastguard Worker            class Mul2(Function):
1779*da0073e9SAndroid Build Coastguard Worker                @staticmethod
1780*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x):
1781*da0073e9SAndroid Build Coastguard Worker                    return x * 2
1782*da0073e9SAndroid Build Coastguard Worker
1783*da0073e9SAndroid Build Coastguard Worker                @staticmethod
1784*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, gO):
1785*da0073e9SAndroid Build Coastguard Worker                    return gO * 2
1786*da0073e9SAndroid Build Coastguard Worker
1787*da0073e9SAndroid Build Coastguard Worker            return Mul2.apply
1788*da0073e9SAndroid Build Coastguard Worker        else:
1789*da0073e9SAndroid Build Coastguard Worker            return lambda x: x * 2
1790*da0073e9SAndroid Build Coastguard Worker
1791*da0073e9SAndroid Build Coastguard Worker    def test_grad_fn_prehooks(self):
1792*da0073e9SAndroid Build Coastguard Worker        for use_custom_function in (True, False):
1793*da0073e9SAndroid Build Coastguard Worker            mul2 = self._get_mul2(use_custom_function)
1794*da0073e9SAndroid Build Coastguard Worker
1795*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([1.0], requires_grad=True)
1796*da0073e9SAndroid Build Coastguard Worker            b = mul2(a)
1797*da0073e9SAndroid Build Coastguard Worker
1798*da0073e9SAndroid Build Coastguard Worker            post_counter = [0]
1799*da0073e9SAndroid Build Coastguard Worker            pre_counter = [0]
1800*da0073e9SAndroid Build Coastguard Worker
1801*da0073e9SAndroid Build Coastguard Worker            def posthook(grad_input, grad_output):
1802*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(pre_counter[0], 3)
1803*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(grad_output[0], torch.ones(1) * 8))
1804*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(grad_input[0], torch.ones(1) * 16))
1805*da0073e9SAndroid Build Coastguard Worker                post_counter[0] += 1
1806*da0073e9SAndroid Build Coastguard Worker                return grad_input
1807*da0073e9SAndroid Build Coastguard Worker
1808*da0073e9SAndroid Build Coastguard Worker            def prehook(grad_output):
1809*da0073e9SAndroid Build Coastguard Worker                pre_counter[0] += 1
1810*da0073e9SAndroid Build Coastguard Worker                return (grad_output[0] * 2,)
1811*da0073e9SAndroid Build Coastguard Worker
1812*da0073e9SAndroid Build Coastguard Worker            # register posthook x 2
1813*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_hook(posthook)
1814*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_hook(posthook)
1815*da0073e9SAndroid Build Coastguard Worker            # register the counting prehook x 3 (interleaved with no-op lambdas)
1816*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(prehook)
1817*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(lambda x: None)
1818*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(prehook)
1819*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(prehook)
1820*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(lambda x: x)
1821*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(lambda x: None)
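            # The incoming grad of 1 is doubled by each of the three counting
            # prehooks (-> 8) and doubled again by mul2's backward (-> 16),
            # which is what the posthooks check above.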
1822*da0073e9SAndroid Build Coastguard Worker
1823*da0073e9SAndroid Build Coastguard Worker            b.sum().backward()
1824*da0073e9SAndroid Build Coastguard Worker
1825*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(post_counter[0], 2)
1826*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pre_counter[0], 3)
1827*da0073e9SAndroid Build Coastguard Worker
1828*da0073e9SAndroid Build Coastguard Worker            # Return None
1829*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(3, 3, requires_grad=True)
1830*da0073e9SAndroid Build Coastguard Worker            b = mul2(a)
1831*da0073e9SAndroid Build Coastguard Worker
1832*da0073e9SAndroid Build Coastguard Worker            def prehook(grad_output):
1833*da0073e9SAndroid Build Coastguard Worker                pre_counter[0] += 1
1834*da0073e9SAndroid Build Coastguard Worker                return None
1835*da0073e9SAndroid Build Coastguard Worker
1836*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(prehook)
1837*da0073e9SAndroid Build Coastguard Worker            b.sum().backward()
1838*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pre_counter[0], 4)
1839*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
1840*da0073e9SAndroid Build Coastguard Worker
1841*da0073e9SAndroid Build Coastguard Worker    def test_grad_fn_prehooks_multiple_outputs(self):
1842*da0073e9SAndroid Build Coastguard Worker        # Compute gradients without hooks
1843*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(3, 3, requires_grad=True)
1844*da0073e9SAndroid Build Coastguard Worker        var, mean = torch.var_mean(b, dim=0)
1845*da0073e9SAndroid Build Coastguard Worker        (var + mean).sum().backward()
1846*da0073e9SAndroid Build Coastguard Worker
1847*da0073e9SAndroid Build Coastguard Worker        # Compute gradients with hooks
1848*da0073e9SAndroid Build Coastguard Worker        a = b.detach().requires_grad_()
1849*da0073e9SAndroid Build Coastguard Worker        counter = [0]
1850*da0073e9SAndroid Build Coastguard Worker
1851*da0073e9SAndroid Build Coastguard Worker        def prehook(grad_output):
1852*da0073e9SAndroid Build Coastguard Worker            gvar, gmean = grad_output
1853*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
1854*da0073e9SAndroid Build Coastguard Worker            return (gvar * 2, gmean * 2)
1855*da0073e9SAndroid Build Coastguard Worker
1856*da0073e9SAndroid Build Coastguard Worker        var, mean = torch.var_mean(a, dim=0)
1857*da0073e9SAndroid Build Coastguard Worker        mean.grad_fn.register_prehook(prehook)
1858*da0073e9SAndroid Build Coastguard Worker        (var + mean).sum().backward()
1859*da0073e9SAndroid Build Coastguard Worker
1860*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
1861*da0073e9SAndroid Build Coastguard Worker        # Compare: the prehook doubled both incoming grads, so a.grad == 2 * b.grad
1862*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(a.grad, b.grad * 2))
1863*da0073e9SAndroid Build Coastguard Worker
1864*da0073e9SAndroid Build Coastguard Worker        # Test with custom Function
1865*da0073e9SAndroid Build Coastguard Worker        class DoubleMul2(Function):
1866*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1867*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, a, y):
1868*da0073e9SAndroid Build Coastguard Worker                ctx.a = a
1869*da0073e9SAndroid Build Coastguard Worker                return a * x * 2, a, a * y * 2
1870*da0073e9SAndroid Build Coastguard Worker
1871*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1872*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, g1, _a, g2):
1873*da0073e9SAndroid Build Coastguard Worker                return ctx.a * g1 * 2, None, ctx.a * g2 * 2
1874*da0073e9SAndroid Build Coastguard Worker
1875*da0073e9SAndroid Build Coastguard Worker        counter = [0]
1876*da0073e9SAndroid Build Coastguard Worker
1877*da0073e9SAndroid Build Coastguard Worker        def prehook(grad_output):
1878*da0073e9SAndroid Build Coastguard Worker            g1, ga, g2 = grad_output
1879*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(ga)
1880*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
1881*da0073e9SAndroid Build Coastguard Worker            return (g1 * 2, None, g2 * 2)
1882*da0073e9SAndroid Build Coastguard Worker
1883*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 3, requires_grad=True)
1884*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, 3, requires_grad=True)
1885*da0073e9SAndroid Build Coastguard Worker        k = 3
1886*da0073e9SAndroid Build Coastguard Worker        c, _, d = DoubleMul2.apply(a, k, b)
1887*da0073e9SAndroid Build Coastguard Worker        c.grad_fn.register_prehook(prehook)
1888*da0073e9SAndroid Build Coastguard Worker        (c + d).sum().backward()
1889*da0073e9SAndroid Build Coastguard Worker
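        # The incoming grads of 1 are doubled by the prehook and doubled again
        # in backward, then scaled by ctx.a == k, giving grads of 4 * k.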
1890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
1891*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(a.grad, torch.ones(1) * 4 * k))
1892*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(b.grad, torch.ones(1) * 4 * k))
1893*da0073e9SAndroid Build Coastguard Worker
1894*da0073e9SAndroid Build Coastguard Worker    def test_grad_fn_prehooks_remove_hooks(self):
1895*da0073e9SAndroid Build Coastguard Worker        for use_custom_function in (True, False):
1896*da0073e9SAndroid Build Coastguard Worker            mul2 = self._get_mul2(use_custom_function)
1897*da0073e9SAndroid Build Coastguard Worker
1898*da0073e9SAndroid Build Coastguard Worker            # Simply remove hooks
1899*da0073e9SAndroid Build Coastguard Worker
1900*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(3, 3, requires_grad=True)
1901*da0073e9SAndroid Build Coastguard Worker            b = mul2(a)
1902*da0073e9SAndroid Build Coastguard Worker            counter = [0]
1903*da0073e9SAndroid Build Coastguard Worker
1904*da0073e9SAndroid Build Coastguard Worker            def prehook(grad_output):
1905*da0073e9SAndroid Build Coastguard Worker                counter[0] += 1
1906*da0073e9SAndroid Build Coastguard Worker                return None
1907*da0073e9SAndroid Build Coastguard Worker
1908*da0073e9SAndroid Build Coastguard Worker            handle = b.grad_fn.register_prehook(prehook)
1909*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(prehook)
1910*da0073e9SAndroid Build Coastguard Worker            handle.remove()
1911*da0073e9SAndroid Build Coastguard Worker            b.sum().backward()
1912*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
1913*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counter[0], 1)
1914*da0073e9SAndroid Build Coastguard Worker
1915*da0073e9SAndroid Build Coastguard Worker            # Remove hooks during backward
1916*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(3, 3, requires_grad=True)
1917*da0073e9SAndroid Build Coastguard Worker            b = mul2(a)
1918*da0073e9SAndroid Build Coastguard Worker            counter = [0]
1919*da0073e9SAndroid Build Coastguard Worker
1920*da0073e9SAndroid Build Coastguard Worker            def prehook1(grad_output):
1921*da0073e9SAndroid Build Coastguard Worker                handle2.remove()
1922*da0073e9SAndroid Build Coastguard Worker                # Removing a hook that has already been removed is OK
1923*da0073e9SAndroid Build Coastguard Worker                handle3.remove()
1924*da0073e9SAndroid Build Coastguard Worker                return None
1925*da0073e9SAndroid Build Coastguard Worker
1926*da0073e9SAndroid Build Coastguard Worker            def prehook2(grad_output):
1927*da0073e9SAndroid Build Coastguard Worker                counter[0] += 1
1928*da0073e9SAndroid Build Coastguard Worker                return None
1929*da0073e9SAndroid Build Coastguard Worker
1930*da0073e9SAndroid Build Coastguard Worker            # Hooks that were registered first run first
1931*da0073e9SAndroid Build Coastguard Worker            b.grad_fn.register_prehook(prehook1)
1932*da0073e9SAndroid Build Coastguard Worker            handle2 = b.grad_fn.register_prehook(prehook2)
1933*da0073e9SAndroid Build Coastguard Worker            handle3 = b.grad_fn.register_prehook(prehook2)
1934*da0073e9SAndroid Build Coastguard Worker            handle3.remove()
1935*da0073e9SAndroid Build Coastguard Worker            b.sum().backward()
1936*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
1937*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counter[0], 1)
1938*da0073e9SAndroid Build Coastguard Worker
1939*da0073e9SAndroid Build Coastguard Worker    def test_node_post_hook_registered_during_unpack_hook(self):
1940*da0073e9SAndroid Build Coastguard Worker        """
1941*da0073e9SAndroid Build Coastguard Worker        Test that post hooks registered during one of the node's
1942*da0073e9SAndroid Build Coastguard Worker        unpack hooks are properly restricted and run as expected.
1943*da0073e9SAndroid Build Coastguard Worker        """
1944*da0073e9SAndroid Build Coastguard Worker        test_case = self
1945*da0073e9SAndroid Build Coastguard Worker
1946*da0073e9SAndroid Build Coastguard Worker        class RegisterPostNodeHook(torch.autograd.graph.saved_tensors_hooks):
1947*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
1948*da0073e9SAndroid Build Coastguard Worker                def pack_tensor(tensor: torch.Tensor) -> torch.Tensor:
1949*da0073e9SAndroid Build Coastguard Worker                    return tensor
1950*da0073e9SAndroid Build Coastguard Worker
1951*da0073e9SAndroid Build Coastguard Worker                def unpack_tensor(tensor: torch.Tensor) -> torch.Tensor:
1952*da0073e9SAndroid Build Coastguard Worker                    node = torch._C._current_autograd_node()
1953*da0073e9SAndroid Build Coastguard Worker
1954*da0073e9SAndroid Build Coastguard Worker                    def hook(outputs, inputs):
1955*da0073e9SAndroid Build Coastguard Worker                        # Assert that inputs passed in are None
1956*da0073e9SAndroid Build Coastguard Worker                        test_case.assertTrue(all(i is None for i in inputs))
1957*da0073e9SAndroid Build Coastguard Worker                        halved_outputs = tuple(
1958*da0073e9SAndroid Build Coastguard Worker                            o / 2.0 if o is not None else None for o in outputs
1959*da0073e9SAndroid Build Coastguard Worker                        )
1960*da0073e9SAndroid Build Coastguard Worker                        return halved_outputs
1961*da0073e9SAndroid Build Coastguard Worker
1962*da0073e9SAndroid Build Coastguard Worker                    node.register_hook(hook)
1963*da0073e9SAndroid Build Coastguard Worker                    return tensor
1964*da0073e9SAndroid Build Coastguard Worker
1965*da0073e9SAndroid Build Coastguard Worker                super().__init__(pack_tensor, unpack_tensor)
1966*da0073e9SAndroid Build Coastguard Worker
1967*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(3, 3, requires_grad=True)
1968*da0073e9SAndroid Build Coastguard Worker
1969*da0073e9SAndroid Build Coastguard Worker        def model():
1970*da0073e9SAndroid Build Coastguard Worker            var, mean = torch.var_mean(a, dim=0)
1971*da0073e9SAndroid Build Coastguard Worker            loss = (var + mean).sum()
1972*da0073e9SAndroid Build Coastguard Worker            loss.backward()
1973*da0073e9SAndroid Build Coastguard Worker
1974*da0073e9SAndroid Build Coastguard Worker        model()
1975*da0073e9SAndroid Build Coastguard Worker        ref_grad = a.grad.clone()
1976*da0073e9SAndroid Build Coastguard Worker
1977*da0073e9SAndroid Build Coastguard Worker        with RegisterPostNodeHook():
1978*da0073e9SAndroid Build Coastguard Worker            model()
1979*da0073e9SAndroid Build Coastguard Worker
1980*da0073e9SAndroid Build Coastguard Worker        # Verify that the post hook got called and the grad propagation worked
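        # (a.grad still holds ref_grad from the first call; the hook halves the
        # grads of the hooked call, so the total is ref_grad + ref_grad / 2.)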
1981*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(ref_grad / 2.0 + ref_grad, a.grad)
1982*da0073e9SAndroid Build Coastguard Worker
1983*da0073e9SAndroid Build Coastguard Worker    def test_hooks_cpp(self):
1984*da0073e9SAndroid Build Coastguard Worker        # Tests hooks for autograd function implemented in C++
1985*da0073e9SAndroid Build Coastguard Worker        bn = torch.nn.BatchNorm1d(5, affine=False)
1986*da0073e9SAndroid Build Coastguard Worker        bn.double()
1987*da0073e9SAndroid Build Coastguard Worker        bn.eval()
1988*da0073e9SAndroid Build Coastguard Worker
1989*da0073e9SAndroid Build Coastguard Worker        counter = [0]
1990*da0073e9SAndroid Build Coastguard Worker
1991*da0073e9SAndroid Build Coastguard Worker        def bw_hook(grad):
1992*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
1993*da0073e9SAndroid Build Coastguard Worker            return grad * 2
1994*da0073e9SAndroid Build Coastguard Worker
1995*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, dtype=torch.double, requires_grad=True)
1996*da0073e9SAndroid Build Coastguard Worker        z = bn(x)
1997*da0073e9SAndroid Build Coastguard Worker        z.register_hook(bw_hook)
1998*da0073e9SAndroid Build Coastguard Worker        z.sum().backward()
1999*da0073e9SAndroid Build Coastguard Worker
2000*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1, msg="bw_hook not called")
2001*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
2002*da0073e9SAndroid Build Coastguard Worker            x.grad, torch.ones(5, 5, dtype=torch.double) * 2, atol=1e-5, rtol=0
2003*da0073e9SAndroid Build Coastguard Worker        )
2004*da0073e9SAndroid Build Coastguard Worker
2005*da0073e9SAndroid Build Coastguard Worker    def test_hook_none(self):
2006*da0073e9SAndroid Build Coastguard Worker        # WARNING: this is a test for autograd internals.
2007*da0073e9SAndroid Build Coastguard Worker        # You should never have to use such things in your code.
2008*da0073e9SAndroid Build Coastguard Worker        class NoneGradientFunction(Function):
2009*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2010*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
2011*da0073e9SAndroid Build Coastguard Worker                assert ctx.needs_input_grad[0]
2012*da0073e9SAndroid Build Coastguard Worker                assert not ctx.needs_input_grad[1]
2013*da0073e9SAndroid Build Coastguard Worker                return x, y
2014*da0073e9SAndroid Build Coastguard Worker
2015*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2016*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_x, grad_y):
2017*da0073e9SAndroid Build Coastguard Worker                return grad_x, None
2018*da0073e9SAndroid Build Coastguard Worker
2019*da0073e9SAndroid Build Coastguard Worker        was_called = [False]
2020*da0073e9SAndroid Build Coastguard Worker
2021*da0073e9SAndroid Build Coastguard Worker        def hook(grad):
2022*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(grad)
2023*da0073e9SAndroid Build Coastguard Worker            was_called[0] = True
2024*da0073e9SAndroid Build Coastguard Worker
2025*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
2026*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5)
2027*da0073e9SAndroid Build Coastguard Worker        rx, ry = NoneGradientFunction.apply(x, y)
2028*da0073e9SAndroid Build Coastguard Worker        rx.register_hook(hook)
2029*da0073e9SAndroid Build Coastguard Worker        ry.register_hook(hook)
2030*da0073e9SAndroid Build Coastguard Worker        sum(rx, ry).sum().backward()
2031*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(was_called[0])
2032*da0073e9SAndroid Build Coastguard Worker
2033*da0073e9SAndroid Build Coastguard Worker    def test_retain_grad(self):
2034*da0073e9SAndroid Build Coastguard Worker        input = torch.rand(1, 3, requires_grad=True)
2035*da0073e9SAndroid Build Coastguard Worker        h1 = input * 3
2036*da0073e9SAndroid Build Coastguard Worker        out = (h1 * h1).sum()
2037*da0073e9SAndroid Build Coastguard Worker
2038*da0073e9SAndroid Build Coastguard Worker        # It should be possible to call retain_grad() multiple times
2039*da0073e9SAndroid Build Coastguard Worker        h1.retain_grad()
2040*da0073e9SAndroid Build Coastguard Worker        h1.retain_grad()
2041*da0073e9SAndroid Build Coastguard Worker
2042*da0073e9SAndroid Build Coastguard Worker        # Gradient should be accumulated
2043*da0073e9SAndroid Build Coastguard Worker        out.backward(retain_graph=True)
2044*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(h1 * 2, h1.grad)
2045*da0073e9SAndroid Build Coastguard Worker        out.backward(retain_graph=True)
2046*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(h1 * 4, h1.grad)
2047*da0073e9SAndroid Build Coastguard Worker
2048*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2049*da0073e9SAndroid Build Coastguard Worker            input.grad.zero_()
2050*da0073e9SAndroid Build Coastguard Worker        # It should be a no-op for leaves
2051*da0073e9SAndroid Build Coastguard Worker        input.retain_grad()
2052*da0073e9SAndroid Build Coastguard Worker        input.retain_grad()
2053*da0073e9SAndroid Build Coastguard Worker        out.backward()
2054*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input * 18, input.grad)
2055*da0073e9SAndroid Build Coastguard Worker
2056*da0073e9SAndroid Build Coastguard Worker    # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between
2057*da0073e9SAndroid Build Coastguard Worker    #     retains_grad and hooks in cpp
2058*da0073e9SAndroid Build Coastguard Worker    def test_retain_grad_inplace(self):
2059*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0], requires_grad=True).clone()
2060*da0073e9SAndroid Build Coastguard Worker        a.retain_grad()
2061*da0073e9SAndroid Build Coastguard Worker        a.mul_(2)
2062*da0073e9SAndroid Build Coastguard Worker        a.sum().backward()
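        # retain_grad tracks the latest version of a, so the retained grad is 1
        # (the grad of a.sum() wrt the post-mul_ value), not 2.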
2063*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.tensor([1.0]))
2064*da0073e9SAndroid Build Coastguard Worker
2065*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0], requires_grad=True).clone()
2066*da0073e9SAndroid Build Coastguard Worker        a.retain_grad()
2067*da0073e9SAndroid Build Coastguard Worker        # Inplace multiple times is OK
2068*da0073e9SAndroid Build Coastguard Worker        a.mul_(2)
2069*da0073e9SAndroid Build Coastguard Worker        a.mul_(2)
2070*da0073e9SAndroid Build Coastguard Worker        a.sum().backward()
2071*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.tensor([1.0]))
2072*da0073e9SAndroid Build Coastguard Worker
2073*da0073e9SAndroid Build Coastguard Worker        # When an in-place op is done over a view, the retains_grad hooks should
2074*da0073e9SAndroid Build Coastguard Worker        # be moved from the base's original grad_fn to the CopySlices node.
2075*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0], requires_grad=True).clone()
2076*da0073e9SAndroid Build Coastguard Worker        x.retain_grad()
2077*da0073e9SAndroid Build Coastguard Worker        x_view = x[:]
2078*da0073e9SAndroid Build Coastguard Worker        x_view *= 2
2079*da0073e9SAndroid Build Coastguard Worker        x *= 2
2080*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
2081*da0073e9SAndroid Build Coastguard Worker        # The grad is 1, not 4, because we are computing grad wrt the latest
2082*da0073e9SAndroid Build Coastguard Worker        # version of x.
2083*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.tensor([1.0]))
2084*da0073e9SAndroid Build Coastguard Worker
2085*da0073e9SAndroid Build Coastguard Worker        # If the base did not originally require grad, there should be no hook
2086*da0073e9SAndroid Build Coastguard Worker        # to move. Make sure this case runs without error.
2087*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(4)
2088*da0073e9SAndroid Build Coastguard Worker        y = x.view(2, 2)
2089*da0073e9SAndroid Build Coastguard Worker        y.add_(torch.randn(2, 2, requires_grad=True))
2090*da0073e9SAndroid Build Coastguard Worker
2091*da0073e9SAndroid Build Coastguard Worker    def test_retains_grad_inplace_multiple_outputs(self):
2092*da0073e9SAndroid Build Coastguard Worker        class DoubleMul(Function):
2093*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2094*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
2095*da0073e9SAndroid Build Coastguard Worker                return x * 2, x * 3
2096*da0073e9SAndroid Build Coastguard Worker
2097*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2098*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, g1, g2):
2099*da0073e9SAndroid Build Coastguard Worker                return g1 * 2 + g2 * 3
2100*da0073e9SAndroid Build Coastguard Worker
2101*da0073e9SAndroid Build Coastguard Worker        var_mean = partial(torch.var_mean, dim=0)
2102*da0073e9SAndroid Build Coastguard Worker
2103*da0073e9SAndroid Build Coastguard Worker        for fn in (DoubleMul.apply, var_mean):
2104*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(3, 3, requires_grad=True)
2105*da0073e9SAndroid Build Coastguard Worker            var, mean = fn(b)
2106*da0073e9SAndroid Build Coastguard Worker            var.retain_grad()
2107*da0073e9SAndroid Build Coastguard Worker            mean.retain_grad()
2108*da0073e9SAndroid Build Coastguard Worker            # node has two retains_grad hooks
2109*da0073e9SAndroid Build Coastguard Worker            var.mul_(2)
2110*da0073e9SAndroid Build Coastguard Worker            # the retains_grad hook that the multi-output node refers to should now be a nullptr
2111*da0073e9SAndroid Build Coastguard Worker            (var + mean).sum().backward()
2112*da0073e9SAndroid Build Coastguard Worker            gvar = var.grad
2113*da0073e9SAndroid Build Coastguard Worker            gmean = mean.grad
2114*da0073e9SAndroid Build Coastguard Worker
2115*da0073e9SAndroid Build Coastguard Worker            a = b.detach().requires_grad_(True)
2116*da0073e9SAndroid Build Coastguard Worker            var, mean = fn(a)
2117*da0073e9SAndroid Build Coastguard Worker            var.mul_(2)
2118*da0073e9SAndroid Build Coastguard Worker            out = (var + mean).sum()
2119*da0073e9SAndroid Build Coastguard Worker            gvar_expected, gmean_expected = torch.autograd.grad(out, inputs=(var, mean))
2120*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(gvar, gvar_expected))
2121*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(gmean, gmean_expected))
2122*da0073e9SAndroid Build Coastguard Worker
2123*da0073e9SAndroid Build Coastguard Worker    def test_retain_grad_inplace_over_view(self):
2124*da0073e9SAndroid Build Coastguard Worker        base = torch.tensor([1.0], requires_grad=True).clone()
2125*da0073e9SAndroid Build Coastguard Worker        view = base[:]
2126*da0073e9SAndroid Build Coastguard Worker        view2 = base[:]
2127*da0073e9SAndroid Build Coastguard Worker        view.retain_grad()
2128*da0073e9SAndroid Build Coastguard Worker        view2.retain_grad()
2129*da0073e9SAndroid Build Coastguard Worker        view.mul_(2)
2130*da0073e9SAndroid Build Coastguard Worker        (view + view2).sum().backward()
2131*da0073e9SAndroid Build Coastguard Worker
2132*da0073e9SAndroid Build Coastguard Worker        # The old grad_fn, slice, wouldn't be part of the graph during backward
2133*da0073e9SAndroid Build Coastguard Worker        # so if the retains grad were not properly updated to the new grad_fn,
2134*da0073e9SAndroid Build Coastguard Worker        # the grad would still be None
2135*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(view.grad, view2.grad)
2136*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(view.grad, torch.tensor([1.0]))
2137*da0073e9SAndroid Build Coastguard Worker
2138*da0073e9SAndroid Build Coastguard Worker    def test_tensor_hooks_inplace(self):
2139*da0073e9SAndroid Build Coastguard Worker        # Check that the second hook gets registered to the new version of the tensor
2140*da0073e9SAndroid Build Coastguard Worker        count1 = [0]
2141*da0073e9SAndroid Build Coastguard Worker        count2 = [0]
2142*da0073e9SAndroid Build Coastguard Worker
2143*da0073e9SAndroid Build Coastguard Worker        def fn1(grad):
2144*da0073e9SAndroid Build Coastguard Worker            count1[0] += 1
2145*da0073e9SAndroid Build Coastguard Worker            # x2 from mul, x2 from fn2
2146*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad, torch.tensor([4.0]))
2147*da0073e9SAndroid Build Coastguard Worker            return grad * 2
2148*da0073e9SAndroid Build Coastguard Worker
2149*da0073e9SAndroid Build Coastguard Worker        def fn2(grad):
2150*da0073e9SAndroid Build Coastguard Worker            count2[0] += 1
2151*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad, torch.tensor([1.0]))
2152*da0073e9SAndroid Build Coastguard Worker            return grad * 2
2153*da0073e9SAndroid Build Coastguard Worker
2154*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0], requires_grad=True)
2155*da0073e9SAndroid Build Coastguard Worker        b = a.clone()
2156*da0073e9SAndroid Build Coastguard Worker        b.register_hook(fn1)
2157*da0073e9SAndroid Build Coastguard Worker        b.mul_(2)
2158*da0073e9SAndroid Build Coastguard Worker        b.register_hook(fn2)
2159*da0073e9SAndroid Build Coastguard Worker        b.sum().backward()
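        # The incoming grad of 1 is doubled by fn2 (new version of b), doubled by
        # the mul_ backward, and doubled again by fn1 (old version), so a.grad == 8.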
2160*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count1[0], 1)
2161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count2[0], 1)
2162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.tensor([8.0]))
2163*da0073e9SAndroid Build Coastguard Worker
2164*da0073e9SAndroid Build Coastguard Worker        count3 = [0]
2165*da0073e9SAndroid Build Coastguard Worker
2166*da0073e9SAndroid Build Coastguard Worker        def fn3(grad):
2167*da0073e9SAndroid Build Coastguard Worker            count3[0] += 1
2168*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad, torch.tensor([4.0]))
2169*da0073e9SAndroid Build Coastguard Worker            return grad * 2
2170*da0073e9SAndroid Build Coastguard Worker
2171*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0], requires_grad=True)
2172*da0073e9SAndroid Build Coastguard Worker        b = a.clone()
2173*da0073e9SAndroid Build Coastguard Worker        b.register_hook(fn3)
2174*da0073e9SAndroid Build Coastguard Worker        # Inplace multiple times is OK
2175*da0073e9SAndroid Build Coastguard Worker        b.mul_(2)
2176*da0073e9SAndroid Build Coastguard Worker        b.mul_(2)
2177*da0073e9SAndroid Build Coastguard Worker        b.sum().backward()
2178*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count3[0], 1)
2179*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.tensor([8.0]))
2180*da0073e9SAndroid Build Coastguard Worker
2181*da0073e9SAndroid Build Coastguard Worker    def test_tensor_hooks_inplace_multiple_outputs(self):
2182*da0073e9SAndroid Build Coastguard Worker        class DoubleMul(Function):
2183*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2184*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
2185*da0073e9SAndroid Build Coastguard Worker                return x * 2, x * 3
2186*da0073e9SAndroid Build Coastguard Worker
2187*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2188*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, g1, g2):
2189*da0073e9SAndroid Build Coastguard Worker                return g1 * 2 + g2 * 3
2190*da0073e9SAndroid Build Coastguard Worker
2191*da0073e9SAndroid Build Coastguard Worker        var_mean = partial(torch.var_mean, dim=0)
2192*da0073e9SAndroid Build Coastguard Worker
2193*da0073e9SAndroid Build Coastguard Worker        for fn in (DoubleMul.apply, var_mean):
2194*da0073e9SAndroid Build Coastguard Worker            counts = [0, 0, 0]
2195*da0073e9SAndroid Build Coastguard Worker
2196*da0073e9SAndroid Build Coastguard Worker            def fn0(grad):
2197*da0073e9SAndroid Build Coastguard Worker                counts[0] += 1
2198*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad, torch.ones_like(out1) * 2)
2199*da0073e9SAndroid Build Coastguard Worker
2200*da0073e9SAndroid Build Coastguard Worker            def fn1(grad):
2201*da0073e9SAndroid Build Coastguard Worker                counts[1] += 1
2202*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad, torch.ones_like(out1) * 3)
2203*da0073e9SAndroid Build Coastguard Worker
2204*da0073e9SAndroid Build Coastguard Worker            def fn2(grad):
2205*da0073e9SAndroid Build Coastguard Worker                counts[2] += 1
2206*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(grad, torch.ones_like(out1))
2207*da0073e9SAndroid Build Coastguard Worker
2208*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(3, 3, requires_grad=True)
2209*da0073e9SAndroid Build Coastguard Worker            out1, out2 = fn(b)
2210*da0073e9SAndroid Build Coastguard Worker            out1.register_hook(fn0)
2211*da0073e9SAndroid Build Coastguard Worker            out2.register_hook(fn1)
            # the grad_fn node now refers to two hook dicts
            # out1 no longer points to its old hook dict
2214*da0073e9SAndroid Build Coastguard Worker            out1.mul_(2)
2215*da0073e9SAndroid Build Coastguard Worker            # fn2 is registered to out1's new hook dict
2216*da0073e9SAndroid Build Coastguard Worker            out1.register_hook(fn2)
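            # Expected hook grads: fn2 sees 1.0s at the new version of out1,
            # mul_'s backward doubles that to 2.0s for fn0, and fn1 sees 3.0s at out2.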
2217*da0073e9SAndroid Build Coastguard Worker            (out1 + out2 * 3).sum().backward()
2218*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counts, [1, 1, 1])
2219*da0073e9SAndroid Build Coastguard Worker
2220*da0073e9SAndroid Build Coastguard Worker    def test_tensor_hooks_inplace_over_view(self):
2221*da0073e9SAndroid Build Coastguard Worker        # There might be a better UX here, but this is the way it is now
2222*da0073e9SAndroid Build Coastguard Worker        count = [0]
2223*da0073e9SAndroid Build Coastguard Worker
2224*da0073e9SAndroid Build Coastguard Worker        def fn0(grad):
2225*da0073e9SAndroid Build Coastguard Worker            self.fail()
2226*da0073e9SAndroid Build Coastguard Worker
2227*da0073e9SAndroid Build Coastguard Worker        def fn1(grad):
2228*da0073e9SAndroid Build Coastguard Worker            self.fail()
2229*da0073e9SAndroid Build Coastguard Worker
2230*da0073e9SAndroid Build Coastguard Worker        def fn2(grad):
2231*da0073e9SAndroid Build Coastguard Worker            count[0] += 1
2232*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad, torch.tensor([1.0]))
2233*da0073e9SAndroid Build Coastguard Worker
2234*da0073e9SAndroid Build Coastguard Worker        base = torch.tensor([1.0], requires_grad=True).clone()
2235*da0073e9SAndroid Build Coastguard Worker        view = base[:]
2236*da0073e9SAndroid Build Coastguard Worker        view2 = base[:]
2237*da0073e9SAndroid Build Coastguard Worker        view.register_hook(fn0)
2238*da0073e9SAndroid Build Coastguard Worker        view2.register_hook(fn1)
2239*da0073e9SAndroid Build Coastguard Worker        view.mul_(2)
        # We need to explicitly access view2.grad_fn to trigger its lazy update
2241*da0073e9SAndroid Build Coastguard Worker        view2.grad_fn
2242*da0073e9SAndroid Build Coastguard Worker        view2.register_hook(fn2)
2243*da0073e9SAndroid Build Coastguard Worker        (view + view2).sum().backward()
        # The hooks originally registered to the views are not fired; one must
        # explicitly trigger an update to the view's grad_fn and then register a new hook
2246*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 1)
2247*da0073e9SAndroid Build Coastguard Worker
2248*da0073e9SAndroid Build Coastguard Worker    def test_retain_grad_cycle(self):
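        # retain_grad() should not create a reference cycle that keeps the
        # intermediate tensor alive once it goes out of scope.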
2249*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
2250*da0073e9SAndroid Build Coastguard Worker
2251*da0073e9SAndroid Build Coastguard Worker        def run_test():
2252*da0073e9SAndroid Build Coastguard Worker            y = x * 2
2253*da0073e9SAndroid Build Coastguard Worker            y.retain_grad()
2254*da0073e9SAndroid Build Coastguard Worker
2255*da0073e9SAndroid Build Coastguard Worker            return y / 2, torch._C._WeakTensorRef(y)
2256*da0073e9SAndroid Build Coastguard Worker
2257*da0073e9SAndroid Build Coastguard Worker        z, ref = run_test()
2258*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref.expired())
2259*da0073e9SAndroid Build Coastguard Worker        z.sum().backward()
2260*da0073e9SAndroid Build Coastguard Worker
2261*da0073e9SAndroid Build Coastguard Worker    def test_backward(self):
2262*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(5, 5, requires_grad=True)
2263*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
2264*da0073e9SAndroid Build Coastguard Worker        y = (torch.rand(5, 5) + 0.1).requires_grad_(True)
2265*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(5, 5, requires_grad=True)
2266*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(5, 5)
2267*da0073e9SAndroid Build Coastguard Worker
2268*da0073e9SAndroid Build Coastguard Worker        v.backward(grad_output)
2269*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, grad_output)
2270*da0073e9SAndroid Build Coastguard Worker
2271*da0073e9SAndroid Build Coastguard Worker        a = x + (y * z) + 4 * z**2 * x / y
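        # Analytic gradients of a = x + y*z + 4*z**2*x/y:
        #   da/dx = 1 + 4*z**2/y
        #   da/dy = z - 4*x*z**2/y**2
        #   da/dz = y + 8*x*z/y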
2272*da0073e9SAndroid Build Coastguard Worker        a.backward(grad_output)
2273*da0073e9SAndroid Build Coastguard Worker        x_grad = 4 * z.pow(2) / y + 1
2274*da0073e9SAndroid Build Coastguard Worker        y_grad = z - 4 * x * z.pow(2) / y.pow(2)
2275*da0073e9SAndroid Build Coastguard Worker        z_grad = 8 * x * z / y + y
2276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad * grad_output)
2277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad * grad_output)
2278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.grad, z_grad * grad_output)
2279*da0073e9SAndroid Build Coastguard Worker
2280*da0073e9SAndroid Build Coastguard Worker    def test_to_sparse_backward(self):
2281*da0073e9SAndroid Build Coastguard Worker        to_attr_names = (
2282*da0073e9SAndroid Build Coastguard Worker            "to_dense",
2283*da0073e9SAndroid Build Coastguard Worker            "to_sparse",
2284*da0073e9SAndroid Build Coastguard Worker            "to_sparse_csr",
2285*da0073e9SAndroid Build Coastguard Worker            "to_sparse_csc",
2286*da0073e9SAndroid Build Coastguard Worker            "to_sparse_bsr",
2287*da0073e9SAndroid Build Coastguard Worker            "to_sparse_bsc",
2288*da0073e9SAndroid Build Coastguard Worker        )
2289*da0073e9SAndroid Build Coastguard Worker        to_params = ((), (), (), (), (2,), (2,))
2290*da0073e9SAndroid Build Coastguard Worker        to_attr_names_params = dict(zip(to_attr_names, to_params))
2291*da0073e9SAndroid Build Coastguard Worker
2292*da0073e9SAndroid Build Coastguard Worker        def check_inversion_possible(
2293*da0073e9SAndroid Build Coastguard Worker            t, layout1, layout1_params, layout2, layout2_params
2294*da0073e9SAndroid Build Coastguard Worker        ):
2295*da0073e9SAndroid Build Coastguard Worker            l = (layout1, layout2)
2296*da0073e9SAndroid Build Coastguard Worker            p = (layout1_params, layout2_params)
2297*da0073e9SAndroid Build Coastguard Worker            for l1, l2, p1, p2 in ((*l, *p), (*l[::-1], *p[::-1])):
2298*da0073e9SAndroid Build Coastguard Worker                try:
2299*da0073e9SAndroid Build Coastguard Worker                    to_l1 = getattr(t, l1)(*p1)
2300*da0073e9SAndroid Build Coastguard Worker                    to_l2 = getattr(to_l1, l2)(*p2)
2301*da0073e9SAndroid Build Coastguard Worker                except RuntimeError:
2302*da0073e9SAndroid Build Coastguard Worker                    return False
2303*da0073e9SAndroid Build Coastguard Worker
2304*da0073e9SAndroid Build Coastguard Worker            return True
2305*da0073e9SAndroid Build Coastguard Worker
2306*da0073e9SAndroid Build Coastguard Worker        self_strided = torch.rand(4, 4, dtype=torch.double) + 1
2307*da0073e9SAndroid Build Coastguard Worker        grad_strided = torch.rand(4, 4, dtype=torch.double) + 1
2308*da0073e9SAndroid Build Coastguard Worker
2309*da0073e9SAndroid Build Coastguard Worker        for from_to_attr in to_attr_names:
2310*da0073e9SAndroid Build Coastguard Worker            from_params = to_attr_names_params[from_to_attr]
2311*da0073e9SAndroid Build Coastguard Worker            self_from = getattr(self_strided, from_to_attr)(
2312*da0073e9SAndroid Build Coastguard Worker                *from_params
2313*da0073e9SAndroid Build Coastguard Worker            ).requires_grad_(True)
2314*da0073e9SAndroid Build Coastguard Worker
2315*da0073e9SAndroid Build Coastguard Worker            for to_to_attr in to_attr_names[1:]:
2316*da0073e9SAndroid Build Coastguard Worker                to_params = to_attr_names_params[to_to_attr]
2317*da0073e9SAndroid Build Coastguard Worker
2318*da0073e9SAndroid Build Coastguard Worker                if check_inversion_possible(
2319*da0073e9SAndroid Build Coastguard Worker                    self_strided, from_to_attr, from_params, to_to_attr, to_params
2320*da0073e9SAndroid Build Coastguard Worker                ):
2321*da0073e9SAndroid Build Coastguard Worker                    self_to = getattr(self_from, to_to_attr)(*to_params)
2322*da0073e9SAndroid Build Coastguard Worker                    grad_to = getattr(grad_strided, to_to_attr)(*to_params)
2323*da0073e9SAndroid Build Coastguard Worker
2324*da0073e9SAndroid Build Coastguard Worker                    # No gradcheck support for BSR/BSC, so the grads are checked explicitly
2325*da0073e9SAndroid Build Coastguard Worker                    grad_res = torch.autograd.grad(self_to, self_from, grad_to)[0]
2326*da0073e9SAndroid Build Coastguard Worker
2327*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grad_res.layout, self_from.layout)
2328*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(grad_res.to_dense(), grad_strided)
2329*da0073e9SAndroid Build Coastguard Worker
2330*da0073e9SAndroid Build Coastguard Worker    def test_sparse_mm_backward(self):
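        # Gradients of mm with sparse operands should match those of the
        # equivalent all-dense computation (checked below by redoing with dense).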
2331*da0073e9SAndroid Build Coastguard Worker        size = (3, 3)
2332*da0073e9SAndroid Build Coastguard Worker
        mm_test_cases = product([False, True], repeat=4)
2334*da0073e9SAndroid Build Coastguard Worker
2335*da0073e9SAndroid Build Coastguard Worker        for a_req_grad, a_is_sparse, b_req_grad, b_is_sparse in mm_test_cases:
2336*da0073e9SAndroid Build Coastguard Worker            # We should only be testing cases with sparse inputs, and at least one
2337*da0073e9SAndroid Build Coastguard Worker            # input needs to require grad so we can call a backward pass
2338*da0073e9SAndroid Build Coastguard Worker            if not ((a_is_sparse or b_is_sparse) and (a_req_grad or b_req_grad)):
2339*da0073e9SAndroid Build Coastguard Worker                continue
2340*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(size)
2341*da0073e9SAndroid Build Coastguard Worker            if a_is_sparse:
2342*da0073e9SAndroid Build Coastguard Worker                # detaching as `a` needs to be a leaf
2343*da0073e9SAndroid Build Coastguard Worker                a = a.to_sparse().detach()
2344*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(size)
2345*da0073e9SAndroid Build Coastguard Worker            if b_is_sparse:
2346*da0073e9SAndroid Build Coastguard Worker                # detaching as `b` needs to be a leaf
2347*da0073e9SAndroid Build Coastguard Worker                b = b.to_sparse().detach()
2348*da0073e9SAndroid Build Coastguard Worker
2349*da0073e9SAndroid Build Coastguard Worker            a = a.requires_grad_(a_req_grad)
2350*da0073e9SAndroid Build Coastguard Worker            b = b.requires_grad_(b_req_grad)
2351*da0073e9SAndroid Build Coastguard Worker
2352*da0073e9SAndroid Build Coastguard Worker            r = a.mm(b)
            r.sum().backward()
2354*da0073e9SAndroid Build Coastguard Worker            a_grad = None if a.grad is None else a.grad.clone().detach()
2355*da0073e9SAndroid Build Coastguard Worker            b_grad = None if b.grad is None else b.grad.clone().detach()
2356*da0073e9SAndroid Build Coastguard Worker
2357*da0073e9SAndroid Build Coastguard Worker            # Redo with only dense tensors
2358*da0073e9SAndroid Build Coastguard Worker            a = (
2359*da0073e9SAndroid Build Coastguard Worker                (a.to_dense() if a.is_sparse else a)
2360*da0073e9SAndroid Build Coastguard Worker                .clone()
2361*da0073e9SAndroid Build Coastguard Worker                .detach()
2362*da0073e9SAndroid Build Coastguard Worker                .requires_grad_(a_req_grad)
2363*da0073e9SAndroid Build Coastguard Worker            )
2364*da0073e9SAndroid Build Coastguard Worker            b = (
2365*da0073e9SAndroid Build Coastguard Worker                (b.to_dense() if b.is_sparse else b)
2366*da0073e9SAndroid Build Coastguard Worker                .clone()
2367*da0073e9SAndroid Build Coastguard Worker                .detach()
2368*da0073e9SAndroid Build Coastguard Worker                .requires_grad_(b_req_grad)
2369*da0073e9SAndroid Build Coastguard Worker            )
2370*da0073e9SAndroid Build Coastguard Worker
2371*da0073e9SAndroid Build Coastguard Worker            r = a.mm(b)
2372*da0073e9SAndroid Build Coastguard Worker            r.sum().backward()
2373*da0073e9SAndroid Build Coastguard Worker
2374*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a_grad, a.grad)
2375*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(b_grad, b.grad)
2376*da0073e9SAndroid Build Coastguard Worker
2377*da0073e9SAndroid Build Coastguard Worker    def test_multi_backward(self):
2378*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
2379*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=True)
2380*da0073e9SAndroid Build Coastguard Worker
2381*da0073e9SAndroid Build Coastguard Worker        q = torch.randn(5, 5, requires_grad=True)
2382*da0073e9SAndroid Build Coastguard Worker
2383*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5, requires_grad=True)
2384*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 5, requires_grad=True)
2385*da0073e9SAndroid Build Coastguard Worker
2386*da0073e9SAndroid Build Coastguard Worker        q2 = q * 2
2387*da0073e9SAndroid Build Coastguard Worker        z = x + y + q2
2388*da0073e9SAndroid Build Coastguard Worker        c = a * b + q2
2389*da0073e9SAndroid Build Coastguard Worker        grad_z = torch.randn(5, 5)
2390*da0073e9SAndroid Build Coastguard Worker        grad_c = torch.randn(5, 5)
2391*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward([z, c], [grad_z, grad_c])
2392*da0073e9SAndroid Build Coastguard Worker
2393*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, grad_z)
2394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, grad_z)
2395*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, grad_c * b)
2396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, grad_c * a)
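        # q2 = q * 2 feeds both z and c, so q accumulates both upstream grads,
        # each scaled by 2 by the mul backward.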
2397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(q.grad, (grad_c + grad_z) * 2)
2398*da0073e9SAndroid Build Coastguard Worker
2399*da0073e9SAndroid Build Coastguard Worker    def test_multi_backward_no_grad(self):
2400*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
2401*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=False)
2402*da0073e9SAndroid Build Coastguard Worker
2403*da0073e9SAndroid Build Coastguard Worker        z = x + y
2404*da0073e9SAndroid Build Coastguard Worker        q = y * 2
2405*da0073e9SAndroid Build Coastguard Worker
        # NB: we currently raise an exception if any arguments to backward()
        # have requires_grad=False and don't have a grad_fn. We may want to
        # relax that check to a warning.
2409*da0073e9SAndroid Build Coastguard Worker        def call_backwards():
2410*da0073e9SAndroid Build Coastguard Worker            torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
2411*da0073e9SAndroid Build Coastguard Worker
2412*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, call_backwards)
2413*da0073e9SAndroid Build Coastguard Worker
2414*da0073e9SAndroid Build Coastguard Worker    def test_backward_with_inputs(self):
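        # The inputs= argument should restrict gradient accumulation to the
        # listed tensors; the others keep their existing (zeroed) .grad.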
2415*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2416*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2417*da0073e9SAndroid Build Coastguard Worker
2418*da0073e9SAndroid Build Coastguard Worker        def fn():
2419*da0073e9SAndroid Build Coastguard Worker            return x**2 + y * x + y**2
2420*da0073e9SAndroid Build Coastguard Worker
2421*da0073e9SAndroid Build Coastguard Worker        gradient = torch.ones(2, 2)
2422*da0073e9SAndroid Build Coastguard Worker        x_grad_expected = 2 * x + y
2423*da0073e9SAndroid Build Coastguard Worker        y_grad_expected = x + 2 * y
2424*da0073e9SAndroid Build Coastguard Worker
2425*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2426*da0073e9SAndroid Build Coastguard Worker        def reset_grad():
2427*da0073e9SAndroid Build Coastguard Worker            x.grad.zero_()
2428*da0073e9SAndroid Build Coastguard Worker            y.grad.zero_()
2429*da0073e9SAndroid Build Coastguard Worker
2430*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(fn(), gradient, inputs=[x, y])
2431*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad_expected)
2432*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad_expected)
2433*da0073e9SAndroid Build Coastguard Worker
2434*da0073e9SAndroid Build Coastguard Worker        reset_grad()
2435*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(fn(), gradient, inputs=[x])
2436*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad_expected)
2437*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, torch.zeros(2, 2), exact_dtype=False)
2438*da0073e9SAndroid Build Coastguard Worker
2439*da0073e9SAndroid Build Coastguard Worker        reset_grad()
2440*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(fn(), gradient, inputs=[y])
2441*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad_expected)
2442*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
2443*da0073e9SAndroid Build Coastguard Worker
2444*da0073e9SAndroid Build Coastguard Worker        reset_grad()
2445*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(fn(), gradient, inputs=y)
2446*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad_expected)
2447*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)
2448*da0073e9SAndroid Build Coastguard Worker
2449*da0073e9SAndroid Build Coastguard Worker        reset_grad()
2450*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
2451*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
2452*da0073e9SAndroid Build Coastguard Worker            "cannot be empty",
2453*da0073e9SAndroid Build Coastguard Worker            lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
2454*da0073e9SAndroid Build Coastguard Worker        )
2455*da0073e9SAndroid Build Coastguard Worker
2456*da0073e9SAndroid Build Coastguard Worker    def test_backward_with_nonleaf_inputs(self):
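        # backward(inputs=...) should also populate .grad on non-leaf tensors
        # that are explicitly listed in inputs.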
2457*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2458*da0073e9SAndroid Build Coastguard Worker        x_nonleaf = x * 1
2459*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2460*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
2461*da0073e9SAndroid Build Coastguard Worker
2462*da0073e9SAndroid Build Coastguard Worker        out = x_nonleaf**2 + y * x_nonleaf + y**2
2463*da0073e9SAndroid Build Coastguard Worker
2464*da0073e9SAndroid Build Coastguard Worker        out.backward(
2465*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, dtype=torch.double),
2466*da0073e9SAndroid Build Coastguard Worker            create_graph=True,
2467*da0073e9SAndroid Build Coastguard Worker            inputs=[x, y, x_nonleaf],
2468*da0073e9SAndroid Build Coastguard Worker        )
2469*da0073e9SAndroid Build Coastguard Worker        x_grad_expected = 2 * x + y
2470*da0073e9SAndroid Build Coastguard Worker        y_grad_expected = x + 2 * y
2471*da0073e9SAndroid Build Coastguard Worker        x_non_leaf_expected = 2 * x_nonleaf + y
2472*da0073e9SAndroid Build Coastguard Worker
2473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, y_grad_expected)
2474*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, x_grad_expected)
2475*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)
2476*da0073e9SAndroid Build Coastguard Worker
        # backward() doesn't have an allow_unused flag, so when an input is not
        # part of the graph it behaves as if allow_unused were True:
        # z.grad will simply be None.
2480*da0073e9SAndroid Build Coastguard Worker        out.backward(
2481*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[z]
2482*da0073e9SAndroid Build Coastguard Worker        )
2483*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(z.grad)
2484*da0073e9SAndroid Build Coastguard Worker
2485*da0073e9SAndroid Build Coastguard Worker    def test_dependent_backward(self):
2486*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, requires_grad=True)
2487*da0073e9SAndroid Build Coastguard Worker        y = x**2
2488*da0073e9SAndroid Build Coastguard Worker        z = y**3
2489*da0073e9SAndroid Build Coastguard Worker
2490*da0073e9SAndroid Build Coastguard Worker        go_y = torch.randn(10)
2491*da0073e9SAndroid Build Coastguard Worker        go_z = torch.randn(10)
2492*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward([y, z], [go_y, go_z])
2493*da0073e9SAndroid Build Coastguard Worker
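        # With y = x**2 and z = y**3 = x**6: dy/dx = 2*x and dz/dx = 6*x**5,
        # so both contributions accumulate into x.grad.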
        self.assertEqual(x.grad, 2 * x * go_y + 6 * x.pow(5) * go_z)
2496*da0073e9SAndroid Build Coastguard Worker
2497*da0073e9SAndroid Build Coastguard Worker    def test_save_output_nr(self):
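        # output_nr records which output of its producing Function a tensor is;
        # it should survive save_for_backward (b is MultiOutputFn's second output).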
2498*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, requires_grad=True)
2499*da0073e9SAndroid Build Coastguard Worker
2500*da0073e9SAndroid Build Coastguard Worker        class MultiOutputFn(Function):
2501*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2502*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
2503*da0073e9SAndroid Build Coastguard Worker                return x[:5], x[5:]
2504*da0073e9SAndroid Build Coastguard Worker
2505*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2506*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, *grad):
2507*da0073e9SAndroid Build Coastguard Worker                return torch.cat(grad)
2508*da0073e9SAndroid Build Coastguard Worker
2509*da0073e9SAndroid Build Coastguard Worker        a, b = MultiOutputFn.apply(x)
2510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.output_nr, 1)
2511*da0073e9SAndroid Build Coastguard Worker
2512*da0073e9SAndroid Build Coastguard Worker        class TestFn(Function):
2513*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2514*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, b):
2515*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(b)
2516*da0073e9SAndroid Build Coastguard Worker                return b * 2
2517*da0073e9SAndroid Build Coastguard Worker
2518*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2519*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_b):
2520*da0073e9SAndroid Build Coastguard Worker                (b,) = ctx.saved_tensors
2521*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.output_nr, 1)
2522*da0073e9SAndroid Build Coastguard Worker
2523*da0073e9SAndroid Build Coastguard Worker        TestFn.apply(b).sum().backward()
2524*da0073e9SAndroid Build Coastguard Worker
2525*da0073e9SAndroid Build Coastguard Worker    def test_first_grad_fn_access_in_no_grad_mode(self):
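        # After the in-place add_ on the base, v's grad_fn is rebuilt lazily on
        # first access; that access should also work with grad mode disabled.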
2526*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1 + 1j], requires_grad=True).clone()
2527*da0073e9SAndroid Build Coastguard Worker        v = a.real
2528*da0073e9SAndroid Build Coastguard Worker        a.add_(1)
2529*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.grad_mode.no_grad():
2530*da0073e9SAndroid Build Coastguard Worker            v.grad_fn
2531*da0073e9SAndroid Build Coastguard Worker
2532*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("too slow")
2533*da0073e9SAndroid Build Coastguard Worker    def test_free_deep_graph(self):
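        # Freeing a very deep graph should not rely on deep recursion,
        # otherwise it would overflow the stack.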
2534*da0073e9SAndroid Build Coastguard Worker        def scope():
2535*da0073e9SAndroid Build Coastguard Worker            depth = 150000
2536*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, requires_grad=True)
2537*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
2538*da0073e9SAndroid Build Coastguard Worker
2539*da0073e9SAndroid Build Coastguard Worker            # build a "chain" computation graph
2540*da0073e9SAndroid Build Coastguard Worker            for _ in range(depth):
2541*da0073e9SAndroid Build Coastguard Worker                y = y + y * 0.000001
2542*da0073e9SAndroid Build Coastguard Worker
2543*da0073e9SAndroid Build Coastguard Worker            # graph deletion occurs when the above locals go out of scope.
2544*da0073e9SAndroid Build Coastguard Worker            # In this case `del y` will trigger it but it's easier to leave
2545*da0073e9SAndroid Build Coastguard Worker            # it to Python to delete the locals.
2546*da0073e9SAndroid Build Coastguard Worker
2547*da0073e9SAndroid Build Coastguard Worker        # Should not stack overflow
2548*da0073e9SAndroid Build Coastguard Worker        scope()
2549*da0073e9SAndroid Build Coastguard Worker
2550*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("too slow")
2551*da0073e9SAndroid Build Coastguard Worker    def test_free_deep_graph_complicated(self):
2552*da0073e9SAndroid Build Coastguard Worker        def scope():
2553*da0073e9SAndroid Build Coastguard Worker            depth = 100000
2554*da0073e9SAndroid Build Coastguard Worker            randchoice = torch.randint(2, [depth, 2])
2555*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, requires_grad=True)
2556*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
2557*da0073e9SAndroid Build Coastguard Worker
2558*da0073e9SAndroid Build Coastguard Worker            # Hold the two previous values
2559*da0073e9SAndroid Build Coastguard Worker            prev_values = [None, None]
2560*da0073e9SAndroid Build Coastguard Worker
2561*da0073e9SAndroid Build Coastguard Worker            # Build a "chain with skip connections" graph
            for i in range(depth):
                prev_tensors = [
                    tensor for tensor in prev_values if tensor is not None
                ]
                prev_values.append(y)
                prev_values.pop(0)

                # Definitely pick one tensor to add
                y += y * 0.000001

                # Possibly add the two previous tensors as skip connections
                nprev = len(prev_tensors)
                if nprev == 2:
                    y += randchoice[i].mul(torch.cat(prev_tensors)).sum()
2576*da0073e9SAndroid Build Coastguard Worker
2577*da0073e9SAndroid Build Coastguard Worker            # graph deletion occurs when the above locals go out of scope.
2578*da0073e9SAndroid Build Coastguard Worker
2579*da0073e9SAndroid Build Coastguard Worker        # Should not stack overflow
2580*da0073e9SAndroid Build Coastguard Worker        scope()
2581*da0073e9SAndroid Build Coastguard Worker
2582*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("too slow")
2583*da0073e9SAndroid Build Coastguard Worker    def test_free_deep_graph_pyfunction(self):
2584*da0073e9SAndroid Build Coastguard Worker        class MyOp(Function):
2585*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2586*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, tensor1, tensor2):
2587*da0073e9SAndroid Build Coastguard Worker                return tensor1 + tensor2
2588*da0073e9SAndroid Build Coastguard Worker
2589*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2590*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
2591*da0073e9SAndroid Build Coastguard Worker                return grad_output, grad_output
2592*da0073e9SAndroid Build Coastguard Worker
2593*da0073e9SAndroid Build Coastguard Worker        def scope():
2594*da0073e9SAndroid Build Coastguard Worker            depth = 150000
2595*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(1, requires_grad=True)
2596*da0073e9SAndroid Build Coastguard Worker            y = x.clone()
2597*da0073e9SAndroid Build Coastguard Worker
            # build a deeply nested computation graph
2599*da0073e9SAndroid Build Coastguard Worker            for _ in range(depth):
2600*da0073e9SAndroid Build Coastguard Worker                y = MyOp.apply(y, y)
2601*da0073e9SAndroid Build Coastguard Worker
2602*da0073e9SAndroid Build Coastguard Worker            # graph deletion occurs when the above locals go out of scope.
2603*da0073e9SAndroid Build Coastguard Worker
2604*da0073e9SAndroid Build Coastguard Worker        # Should not stack overflow
2605*da0073e9SAndroid Build Coastguard Worker        scope()
2606*da0073e9SAndroid Build Coastguard Worker
2607*da0073e9SAndroid Build Coastguard Worker    def test_no_unnecessary_save(self):
        # If we kept x saved in the derivative Function of x * ft we would
        # get an error in the backward that would complain that we've
        # modified x, which was needed for gradient computation.
        # Since we should elide unnecessary saves, this test should pass.
2612*da0073e9SAndroid Build Coastguard Worker        mu = torch.ones(1, requires_grad=True)
2613*da0073e9SAndroid Build Coastguard Worker        x = torch.empty(1)
2614*da0073e9SAndroid Build Coastguard Worker        loss = 0
2615*da0073e9SAndroid Build Coastguard Worker        for i in range(3):
2616*da0073e9SAndroid Build Coastguard Worker            x.detach_()
2617*da0073e9SAndroid Build Coastguard Worker            x.copy_(mu + i)
2618*da0073e9SAndroid Build Coastguard Worker            ft = torch.tensor([float(i)])
2619*da0073e9SAndroid Build Coastguard Worker            multiplied = x * ft
2620*da0073e9SAndroid Build Coastguard Worker            s = multiplied.sum()
2621*da0073e9SAndroid Build Coastguard Worker            loss += s
2622*da0073e9SAndroid Build Coastguard Worker        loss.backward()
2623*da0073e9SAndroid Build Coastguard Worker
2624*da0073e9SAndroid Build Coastguard Worker    def test_no_grad(self):
2625*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
2626*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(5, 5) * 4
2627*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2628*da0073e9SAndroid Build Coastguard Worker            w = x + y
2629*da0073e9SAndroid Build Coastguard Worker
2630*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
2631*da0073e9SAndroid Build Coastguard Worker            return x + y
2632*da0073e9SAndroid Build Coastguard Worker
2633*da0073e9SAndroid Build Coastguard Worker        adders = [torch.no_grad()(adder), torch.no_grad(adder)]
2634*da0073e9SAndroid Build Coastguard Worker
2635*da0073e9SAndroid Build Coastguard Worker        for adder in adders:
2636*da0073e9SAndroid Build Coastguard Worker            z = adder(x, y)
2637*da0073e9SAndroid Build Coastguard Worker
2638*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(w.requires_grad)
2639*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
2640*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(w.grad_fn)
2641*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(z.requires_grad)
2642*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
2643*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(z.grad_fn)
2644*da0073e9SAndroid Build Coastguard Worker
2645*da0073e9SAndroid Build Coastguard Worker        # test nested decorator and with-statement on no_grad
2646*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2647*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2648*da0073e9SAndroid Build Coastguard Worker            w = adder(x, y)
2649*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2650*da0073e9SAndroid Build Coastguard Worker
2651*da0073e9SAndroid Build Coastguard Worker    def test_enable_grad_decorator_no_paren(self):
2652*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1, requires_grad=True)
2653*da0073e9SAndroid Build Coastguard Worker
2654*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad
2655*da0073e9SAndroid Build Coastguard Worker        def doubler(x):
2656*da0073e9SAndroid Build Coastguard Worker            return x * 2
2657*da0073e9SAndroid Build Coastguard Worker
2658*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2659*da0073e9SAndroid Build Coastguard Worker            z = doubler(x)
2660*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.requires_grad)
2661*da0073e9SAndroid Build Coastguard Worker
2662*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_generator_functions(self):
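        # The grad-mode decorators should apply only while the generator body is
        # running; the caller's grad mode is restored across each yield.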
2663*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2664*da0073e9SAndroid Build Coastguard Worker        def gen_no_grad():
2665*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
2666*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.is_grad_enabled(), False)
2667*da0073e9SAndroid Build Coastguard Worker                yield i
2668*da0073e9SAndroid Build Coastguard Worker
2669*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
2670*da0073e9SAndroid Build Coastguard Worker            for _ in gen_no_grad():
2671*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.is_grad_enabled(), True)
2672*da0073e9SAndroid Build Coastguard Worker
2673*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad()
2674*da0073e9SAndroid Build Coastguard Worker        def gen_enable_grad():
2675*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
2676*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.is_grad_enabled(), True)
2677*da0073e9SAndroid Build Coastguard Worker                yield i
2678*da0073e9SAndroid Build Coastguard Worker
2679*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2680*da0073e9SAndroid Build Coastguard Worker            for _ in gen_enable_grad():
2681*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(torch.is_grad_enabled(), False)
2682*da0073e9SAndroid Build Coastguard Worker
2683*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_generator_functions_recursive(self):
2684*da0073e9SAndroid Build Coastguard Worker        # enable_grad_decorator_recursive and no_grad_decorator_recursive call each other
2685*da0073e9SAndroid Build Coastguard Worker        # recursively, to ensure that the decorators preserve the caller's setting
2686*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad()
2687*da0073e9SAndroid Build Coastguard Worker        def enable_grad_decorator_recursive(depth):
2688*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2689*da0073e9SAndroid Build Coastguard Worker            if depth > 0:
2690*da0073e9SAndroid Build Coastguard Worker                no_grad_decorator_recursive(depth - 1)
2691*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_grad_enabled())
2692*da0073e9SAndroid Build Coastguard Worker
2693*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2694*da0073e9SAndroid Build Coastguard Worker        def no_grad_decorator_recursive(depth):
2695*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2696*da0073e9SAndroid Build Coastguard Worker            if depth > 0:
2697*da0073e9SAndroid Build Coastguard Worker                enable_grad_decorator_recursive(depth - 1)
2698*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_grad_enabled())
2699*da0073e9SAndroid Build Coastguard Worker
2700*da0073e9SAndroid Build Coastguard Worker        # enable_grad_context_manager_recursive and no_grad_context_manager_recursive call
2701*da0073e9SAndroid Build Coastguard Worker        # each other recursively, to ensure that the decorators preserve the caller's setting
2702*da0073e9SAndroid Build Coastguard Worker        def enable_grad_context_manager_recursive(depth):
2703*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
2704*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_grad_enabled())
2705*da0073e9SAndroid Build Coastguard Worker                if depth > 0:
2706*da0073e9SAndroid Build Coastguard Worker                    no_grad_context_manager_recursive(depth - 1)
2707*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2708*da0073e9SAndroid Build Coastguard Worker
2709*da0073e9SAndroid Build Coastguard Worker        def no_grad_context_manager_recursive(depth):
2710*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
2711*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_grad_enabled())
2712*da0073e9SAndroid Build Coastguard Worker                if depth > 0:
2713*da0073e9SAndroid Build Coastguard Worker                    enable_grad_context_manager_recursive(depth - 1)
2714*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2715*da0073e9SAndroid Build Coastguard Worker
2716*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
2717*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2718*da0073e9SAndroid Build Coastguard Worker            enable_grad_decorator_recursive(10)
2719*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2720*da0073e9SAndroid Build Coastguard Worker            enable_grad_context_manager_recursive(10)
2721*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2722*da0073e9SAndroid Build Coastguard Worker
2723*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2724*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2725*da0073e9SAndroid Build Coastguard Worker            enable_grad_decorator_recursive(10)
2726*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2727*da0073e9SAndroid Build Coastguard Worker            enable_grad_context_manager_recursive(10)
2728*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2729*da0073e9SAndroid Build Coastguard Worker
2730*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_coroutines(self):
2731*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2732*da0073e9SAndroid Build Coastguard Worker        def coro_no_grad(n=10):
2733*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2734*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
2735*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_grad_enabled())
2736*da0073e9SAndroid Build Coastguard Worker                r = yield i
2737*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_grad_enabled())
2738*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(i, r)
2739*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2740*da0073e9SAndroid Build Coastguard Worker
2741*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad()
2742*da0073e9SAndroid Build Coastguard Worker        def coro_enable_grad(n=10):
2743*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2744*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
2745*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_grad_enabled())
2746*da0073e9SAndroid Build Coastguard Worker                r = yield i
2747*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_grad_enabled())
2748*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(i, r)
2749*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2750*da0073e9SAndroid Build Coastguard Worker
2751*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
2752*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_grad_enabled())
2753*da0073e9SAndroid Build Coastguard Worker            coro, r = coro_no_grad(), None
2754*da0073e9SAndroid Build Coastguard Worker            try:
2755*da0073e9SAndroid Build Coastguard Worker                while True:
2756*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2757*da0073e9SAndroid Build Coastguard Worker                    r = coro.send(r)
2758*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2759*da0073e9SAndroid Build Coastguard Worker
2760*da0073e9SAndroid Build Coastguard Worker            except StopIteration:
2761*da0073e9SAndroid Build Coastguard Worker                pass
2762*da0073e9SAndroid Build Coastguard Worker
2763*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2764*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_grad_enabled())
2765*da0073e9SAndroid Build Coastguard Worker            coro, r = coro_enable_grad(), None
2766*da0073e9SAndroid Build Coastguard Worker            try:
2767*da0073e9SAndroid Build Coastguard Worker                while True:
2768*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2769*da0073e9SAndroid Build Coastguard Worker                    r = coro.send(r)
2770*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2771*da0073e9SAndroid Build Coastguard Worker
2772*da0073e9SAndroid Build Coastguard Worker            except StopIteration:
2773*da0073e9SAndroid Build Coastguard Worker                pass
2774*da0073e9SAndroid Build Coastguard Worker
2775*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_coroutines_benign_exceptions(self):
2776*da0073e9SAndroid Build Coastguard Worker        class RecoverableException(Exception):
2777*da0073e9SAndroid Build Coastguard Worker            pass
2778*da0073e9SAndroid Build Coastguard Worker
2779*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2780*da0073e9SAndroid Build Coastguard Worker        def coro_no_grad(n=10):
2781*da0073e9SAndroid Build Coastguard Worker            has_raised = False
2782*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
2783*da0073e9SAndroid Build Coastguard Worker                try:
2784*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2785*da0073e9SAndroid Build Coastguard Worker                    yield (-i if has_raised else i)
2786*da0073e9SAndroid Build Coastguard Worker
2787*da0073e9SAndroid Build Coastguard Worker                except RecoverableException:
2788*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2789*da0073e9SAndroid Build Coastguard Worker                    has_raised = True
2790*da0073e9SAndroid Build Coastguard Worker
2791*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad()
2792*da0073e9SAndroid Build Coastguard Worker        def coro_enable_grad(n=10):
2793*da0073e9SAndroid Build Coastguard Worker            has_raised = False
2794*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
2795*da0073e9SAndroid Build Coastguard Worker                try:
2796*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2797*da0073e9SAndroid Build Coastguard Worker                    yield (-i if has_raised else i)
2798*da0073e9SAndroid Build Coastguard Worker
2799*da0073e9SAndroid Build Coastguard Worker                except RecoverableException:
2800*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2801*da0073e9SAndroid Build Coastguard Worker                    has_raised = True
2802*da0073e9SAndroid Build Coastguard Worker
2803*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
2804*da0073e9SAndroid Build Coastguard Worker            coro = coro_no_grad()
2805*da0073e9SAndroid Build Coastguard Worker            assert 0 == next(coro)
2806*da0073e9SAndroid Build Coastguard Worker            try:
2807*da0073e9SAndroid Build Coastguard Worker                while True:
2808*da0073e9SAndroid Build Coastguard Worker                    r = coro.throw(RecoverableException)
2809*da0073e9SAndroid Build Coastguard Worker                    self.assertLess(r, 0)
2810*da0073e9SAndroid Build Coastguard Worker
2811*da0073e9SAndroid Build Coastguard Worker            except StopIteration:
2812*da0073e9SAndroid Build Coastguard Worker                pass
2813*da0073e9SAndroid Build Coastguard Worker
2814*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2815*da0073e9SAndroid Build Coastguard Worker            coro = coro_enable_grad()
2816*da0073e9SAndroid Build Coastguard Worker            assert 0 == next(coro)
2817*da0073e9SAndroid Build Coastguard Worker            try:
2818*da0073e9SAndroid Build Coastguard Worker                while True:
2819*da0073e9SAndroid Build Coastguard Worker                    r = coro.throw(RecoverableException)
2820*da0073e9SAndroid Build Coastguard Worker                    self.assertLess(r, 0)
2821*da0073e9SAndroid Build Coastguard Worker
2822*da0073e9SAndroid Build Coastguard Worker            except StopIteration:
2823*da0073e9SAndroid Build Coastguard Worker                pass
2824*da0073e9SAndroid Build Coastguard Worker
2825*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_coroutines_critical_exceptions(self):
2826*da0073e9SAndroid Build Coastguard Worker        class UnrecoverableException(Exception):
2827*da0073e9SAndroid Build Coastguard Worker            pass
2828*da0073e9SAndroid Build Coastguard Worker
2829*da0073e9SAndroid Build Coastguard Worker        class SecondaryException(Exception):
2830*da0073e9SAndroid Build Coastguard Worker            pass
2831*da0073e9SAndroid Build Coastguard Worker
2832*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2833*da0073e9SAndroid Build Coastguard Worker        def coro_no_grad(n=10):
2834*da0073e9SAndroid Build Coastguard Worker            has_raised = False
2835*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
2836*da0073e9SAndroid Build Coastguard Worker                try:
2837*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2838*da0073e9SAndroid Build Coastguard Worker                    yield (-i if has_raised else i)
2839*da0073e9SAndroid Build Coastguard Worker
2840*da0073e9SAndroid Build Coastguard Worker                except UnrecoverableException:
2841*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2842*da0073e9SAndroid Build Coastguard Worker                    raise SecondaryException from None
2843*da0073e9SAndroid Build Coastguard Worker
2844*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad()
2845*da0073e9SAndroid Build Coastguard Worker        def coro_enable_grad(n=10):
2846*da0073e9SAndroid Build Coastguard Worker            has_raised = False
2847*da0073e9SAndroid Build Coastguard Worker            for i in range(n):
2848*da0073e9SAndroid Build Coastguard Worker                try:
2849*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2850*da0073e9SAndroid Build Coastguard Worker                    yield (-i if has_raised else i)
2851*da0073e9SAndroid Build Coastguard Worker
2852*da0073e9SAndroid Build Coastguard Worker                except UnrecoverableException:
2853*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2854*da0073e9SAndroid Build Coastguard Worker                    raise SecondaryException from None
2855*da0073e9SAndroid Build Coastguard Worker
2856*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
2857*da0073e9SAndroid Build Coastguard Worker            coro = coro_no_grad()
2858*da0073e9SAndroid Build Coastguard Worker            assert 0 == next(coro)
2859*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(SecondaryException):
2860*da0073e9SAndroid Build Coastguard Worker                coro.throw(UnrecoverableException)
2861*da0073e9SAndroid Build Coastguard Worker
2862*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2863*da0073e9SAndroid Build Coastguard Worker            coro = coro_enable_grad()
2864*da0073e9SAndroid Build Coastguard Worker            assert 0 == next(coro)
2865*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(SecondaryException):
2866*da0073e9SAndroid Build Coastguard Worker                coro.throw(UnrecoverableException)
2867*da0073e9SAndroid Build Coastguard Worker
2868*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_coroutines_exit(self):
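        # coro.close() delivers GeneratorExit inside the coroutine, where the
        # decorator's grad mode should still be in effect.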
2869*da0073e9SAndroid Build Coastguard Worker        @torch.no_grad()
2870*da0073e9SAndroid Build Coastguard Worker        def coro_no_grad(state):
2871*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
2872*da0073e9SAndroid Build Coastguard Worker                try:
2873*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2874*da0073e9SAndroid Build Coastguard Worker                    yield i
2875*da0073e9SAndroid Build Coastguard Worker
2876*da0073e9SAndroid Build Coastguard Worker                except GeneratorExit:
2877*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_grad_enabled())
2878*da0073e9SAndroid Build Coastguard Worker                    state.add("GeneratorExit")
2879*da0073e9SAndroid Build Coastguard Worker                    raise
2880*da0073e9SAndroid Build Coastguard Worker
2881*da0073e9SAndroid Build Coastguard Worker        @torch.enable_grad()
2882*da0073e9SAndroid Build Coastguard Worker        def coro_enable_grad(state):
2883*da0073e9SAndroid Build Coastguard Worker            for i in range(10):
2884*da0073e9SAndroid Build Coastguard Worker                try:
2885*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2886*da0073e9SAndroid Build Coastguard Worker                    yield i
2887*da0073e9SAndroid Build Coastguard Worker
2888*da0073e9SAndroid Build Coastguard Worker                except GeneratorExit:
2889*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
2890*da0073e9SAndroid Build Coastguard Worker                    state.add("GeneratorExit")
2891*da0073e9SAndroid Build Coastguard Worker                    raise
2892*da0073e9SAndroid Build Coastguard Worker
2893*da0073e9SAndroid Build Coastguard Worker        state = set()
2894*da0073e9SAndroid Build Coastguard Worker        with torch.enable_grad():
2895*da0073e9SAndroid Build Coastguard Worker            coro = coro_no_grad(state)
2896*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
2897*da0073e9SAndroid Build Coastguard Worker                next(coro)
2898*da0073e9SAndroid Build Coastguard Worker
2899*da0073e9SAndroid Build Coastguard Worker            coro.close()
2900*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("GeneratorExit" in state)
2901*da0073e9SAndroid Build Coastguard Worker
2902*da0073e9SAndroid Build Coastguard Worker        state = set()
2903*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2904*da0073e9SAndroid Build Coastguard Worker            coro = coro_enable_grad(state)
2905*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
2906*da0073e9SAndroid Build Coastguard Worker                next(coro)
2907*da0073e9SAndroid Build Coastguard Worker
2908*da0073e9SAndroid Build Coastguard Worker            coro.close()
2909*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("GeneratorExit" in state)
2910*da0073e9SAndroid Build Coastguard Worker
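    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical and is not collected by unittest): the decorator re-enters its
    # grad mode each time the generator is resumed and restores the caller's mode
    # at every yield, which is what the coroutine tests above exercise.
    def _sketch_generator_grad_mode(self):
        @torch.no_grad()
        def gen():
            yield torch.is_grad_enabled()

        with torch.enable_grad():
            inside = next(gen())
            # The caller's grad mode is restored at the yield point.
            assert torch.is_grad_enabled()
        # Inside the generator, grad mode was off.
        assert not inside
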
2911*da0073e9SAndroid Build Coastguard Worker    def test_no_grad_python_function(self):
2912*da0073e9SAndroid Build Coastguard Worker        """Python Functions should respect grad mode."""
2913*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
2914*da0073e9SAndroid Build Coastguard Worker
2915*da0073e9SAndroid Build Coastguard Worker        class MyOp(Function):
2916*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2917*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
2918*da0073e9SAndroid Build Coastguard Worker                return x + 1
2919*da0073e9SAndroid Build Coastguard Worker
2920*da0073e9SAndroid Build Coastguard Worker            @staticmethod
2921*da0073e9SAndroid Build Coastguard Worker            def backward(self, dy):
2922*da0073e9SAndroid Build Coastguard Worker                return dy
2923*da0073e9SAndroid Build Coastguard Worker
2924*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
2925*da0073e9SAndroid Build Coastguard Worker            y = MyOp.apply(x)
2926*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.requires_grad)
2927*da0073e9SAndroid Build Coastguard Worker
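    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the complementary case to the test above, showing that the
    # same kind of custom Function does build a graph once grad mode is enabled.
    def _sketch_python_function_with_grad(self):
        class AddOne(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1

            @staticmethod
            def backward(ctx, dy):
                return dy

        x = torch.ones(5, 5, requires_grad=True)
        with torch.enable_grad():
            y = AddOne.apply(x)
        # Outside no_grad the output requires grad and has a grad_fn.
        assert y.requires_grad
        assert y.grad_fn is not None
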
2928*da0073e9SAndroid Build Coastguard Worker    def test_indexing(self):
2929*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1.0, 17).view(4, 4)
2930*da0073e9SAndroid Build Coastguard Worker        y = Variable(x, requires_grad=True)
2931*da0073e9SAndroid Build Coastguard Worker
2932*da0073e9SAndroid Build Coastguard Worker        def compare(x, y, idx, indexed_tensor, indexed_var):
2933*da0073e9SAndroid Build Coastguard Worker            indexed_var_t = indexed_var.data
2934*da0073e9SAndroid Build Coastguard Worker            if not isinstance(indexed_tensor, torch.Tensor):
2935*da0073e9SAndroid Build Coastguard Worker                indexed_var_t = indexed_var_t[0]
2936*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(indexed_tensor, indexed_var_t)
2937*da0073e9SAndroid Build Coastguard Worker
2938*da0073e9SAndroid Build Coastguard Worker            indexed_var.sum().backward()
2939*da0073e9SAndroid Build Coastguard Worker            expected_grad = torch.empty(x.size()).fill_(0)
2940*da0073e9SAndroid Build Coastguard Worker            expected_grad[idx] = 1
2941*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y.grad, expected_grad)
2942*da0073e9SAndroid Build Coastguard Worker
2943*da0073e9SAndroid Build Coastguard Worker        def check_index(x, y, idx):
2944*da0073e9SAndroid Build Coastguard Worker            if y.grad is not None:
2945*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
2946*da0073e9SAndroid Build Coastguard Worker                    y.grad.zero_()
2947*da0073e9SAndroid Build Coastguard Worker            indexed_tensor = x[idx]
2948*da0073e9SAndroid Build Coastguard Worker            indexed_var = y[idx]
2949*da0073e9SAndroid Build Coastguard Worker            compare(x, y, idx, indexed_tensor, indexed_var)
2950*da0073e9SAndroid Build Coastguard Worker
2951*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, 1)
2952*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (1, 1))
2953*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, slice(1, None))
2954*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, slice(None, 2))
2955*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None, 2), 2))
2956*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(1, 2), 2))
2957*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (1, slice(2, None)))
2958*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None, None), slice(2, None)))
2959*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, torch.LongTensor([0, 2]))
2960*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, torch.rand(4, 4).bernoulli().bool())
2961*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (Ellipsis, slice(2, None)))
2962*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0], [0]))
2963*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([1, 2, 3], [0]))
2964*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([1, 2], [2, 1]))
2965*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
2966*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([slice(None), [2, 3]]))
2967*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([[2, 3], slice(None)]))
2968*da0073e9SAndroid Build Coastguard Worker
2969*da0073e9SAndroid Build Coastguard Worker        # advanced indexing, with fewer dims, or an ellipsis
2970*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0]))
2971*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0],))
2972*da0073e9SAndroid Build Coastguard Worker
2973*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1.0, 49).view(4, 3, 4)
2974*da0073e9SAndroid Build Coastguard Worker        y = Variable(x, requires_grad=True)
2975*da0073e9SAndroid Build Coastguard Worker
2976*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None), [0], [0]))
2977*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0], [0], slice(None)))
2978*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None), [0, 1, 2], [0]))
2979*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0, 1, 2], [0], slice(None)))
2980*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None), [1, 2], [2, 1]))
2981*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([1, 2], [2, 1], slice(None)))
2982*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]]))
2983*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None)))
2984*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None), slice(None), [2, 1]))
2985*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (slice(None), [2, 1], slice(None)))
2986*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([2, 1], slice(None), slice(None)))
2987*da0073e9SAndroid Build Coastguard Worker
2988*da0073e9SAndroid Build Coastguard Worker        # advanced indexing, with fewer dims, or an ellipsis
2989*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0],))
2990*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0], slice(None)))
2991*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([0], Ellipsis))
2992*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([1, 2], [0, 1]))
2993*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, ([1, 2], [0, 1], Ellipsis))
2994*da0073e9SAndroid Build Coastguard Worker        check_index(x, y, (Ellipsis, [1, 2], [0, 1]))
2995*da0073e9SAndroid Build Coastguard Worker
2996*da0073e9SAndroid Build Coastguard Worker        # advanced indexing, with a tensor wrapped in a variable
2997*da0073e9SAndroid Build Coastguard Worker        z = torch.LongTensor([0, 1])
2998*da0073e9SAndroid Build Coastguard Worker        zv = Variable(z, requires_grad=False)
2999*da0073e9SAndroid Build Coastguard Worker        seq = [z, Ellipsis]
3000*da0073e9SAndroid Build Coastguard Worker        seqv = [zv, Ellipsis]
3001*da0073e9SAndroid Build Coastguard Worker
3002*da0073e9SAndroid Build Coastguard Worker        if y.grad is not None:
3003*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
3004*da0073e9SAndroid Build Coastguard Worker                y.grad.zero_()
3005*da0073e9SAndroid Build Coastguard Worker        indexed_tensor = x[seq]
3006*da0073e9SAndroid Build Coastguard Worker        indexed_var = y[seqv]
3007*da0073e9SAndroid Build Coastguard Worker        compare(x, y, seq, indexed_tensor, indexed_var)
3008*da0073e9SAndroid Build Coastguard Worker
3009*da0073e9SAndroid Build Coastguard Worker    def test_indexing_duplicates(self):
3010*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1.0, 17).view(4, 4)
3011*da0073e9SAndroid Build Coastguard Worker        y = Variable(x, requires_grad=True)
3012*da0073e9SAndroid Build Coastguard Worker
3013*da0073e9SAndroid Build Coastguard Worker        idx = torch.LongTensor([1, 1, 3, 2, 1, 2])
3014*da0073e9SAndroid Build Coastguard Worker        y[idx].sum().backward()
3015*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.zeros(4, 4)
3016*da0073e9SAndroid Build Coastguard Worker        for i in idx:
3017*da0073e9SAndroid Build Coastguard Worker            expected_grad[i] += 1
3018*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, expected_grad)
3019*da0073e9SAndroid Build Coastguard Worker
3020*da0073e9SAndroid Build Coastguard Worker        # with advanced indexing
3021*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1.0, 17).view(4, 4)
3022*da0073e9SAndroid Build Coastguard Worker        y = Variable(x, requires_grad=True)
3023*da0073e9SAndroid Build Coastguard Worker
3024*da0073e9SAndroid Build Coastguard Worker        idx = [[1, 1, 3, 2, 1, 2], [0]]
3025*da0073e9SAndroid Build Coastguard Worker        y[idx].sum().backward()
3026*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.zeros(4, 4)
3027*da0073e9SAndroid Build Coastguard Worker        for i in idx[0]:
3028*da0073e9SAndroid Build Coastguard Worker            for j in idx[1]:
3029*da0073e9SAndroid Build Coastguard Worker                expected_grad[i][j] += 1
3030*da0073e9SAndroid Build Coastguard Worker
3031*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, expected_grad)
3032*da0073e9SAndroid Build Coastguard Worker
3033*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1.0, 17).view(4, 4)
3034*da0073e9SAndroid Build Coastguard Worker        y = Variable(x, requires_grad=True)
3035*da0073e9SAndroid Build Coastguard Worker        idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
3036*da0073e9SAndroid Build Coastguard Worker        y[idx].sum().backward()
3037*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.tensor(
3038*da0073e9SAndroid Build Coastguard Worker            [
3039*da0073e9SAndroid Build Coastguard Worker                [0.0, 2.0, 0.0, 0.0],
3040*da0073e9SAndroid Build Coastguard Worker                [1.0, 0.0, 0.0, 0.0],
3041*da0073e9SAndroid Build Coastguard Worker                [0.0, 1.0, 0.0, 0.0],
3042*da0073e9SAndroid Build Coastguard Worker                [0.0, 0.0, 0.0, 0.0],
3043*da0073e9SAndroid Build Coastguard Worker            ]
3044*da0073e9SAndroid Build Coastguard Worker        )
3045*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, expected_grad)
3046*da0073e9SAndroid Build Coastguard Worker
3047*da0073e9SAndroid Build Coastguard Worker        x = torch.arange(1.0, 65).view(4, 4, 4)
3048*da0073e9SAndroid Build Coastguard Worker        y = Variable(x, requires_grad=True)
3049*da0073e9SAndroid Build Coastguard Worker
3050*da0073e9SAndroid Build Coastguard Worker        idx = [[1, 1, 1], slice(None), slice(None)]
3051*da0073e9SAndroid Build Coastguard Worker        y[idx].sum().backward()
3052*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.empty(4, 4, 4).zero_()
3053*da0073e9SAndroid Build Coastguard Worker        expected_grad[1].fill_(3)
3054*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, expected_grad)
3055*da0073e9SAndroid Build Coastguard Worker
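    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): duplicates accumulate because the backward of indexing
    # scatter-adds the incoming gradient back into the input shape.
    def _sketch_duplicate_index_grad(self):
        y = torch.zeros(3, requires_grad=True)
        idx = torch.tensor([0, 0, 2])
        y[idx].sum().backward()
        # Position 0 was gathered twice, position 2 once, position 1 never.
        assert torch.equal(y.grad, torch.tensor([2.0, 0.0, 1.0]))
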
3056*da0073e9SAndroid Build Coastguard Worker    def test_index_backward_does_not_save_tensor(self):
3057*da0073e9SAndroid Build Coastguard Worker        # Example from https://github.com/pytorch/pytorch/issues/24853.
3058*da0073e9SAndroid Build Coastguard Worker        # if `index(tensor, indices)` saves `tensor` for backwards, then it will
3059*da0073e9SAndroid Build Coastguard Worker        # trigger a version check on `tensor` during the backward pass, which
3060*da0073e9SAndroid Build Coastguard Worker        # will cause the following code to error because `tensor` gets modified
3061*da0073e9SAndroid Build Coastguard Worker        # by the indexing line.
3062*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0, 0, 0])
3063*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(3, requires_grad=True)
3064*da0073e9SAndroid Build Coastguard Worker        tensor = b + 0
3065*da0073e9SAndroid Build Coastguard Worker        tensor[a != 0] = tensor[a != 0]
3066*da0073e9SAndroid Build Coastguard Worker        tensor.backward(torch.zeros_like(tensor))
3067*da0073e9SAndroid Build Coastguard Worker
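    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the version-counter error that the comment above alludes to
    # does fire for an op that saves its input for backward, e.g. multiplication.
    def _sketch_version_counter_error(self):
        b = torch.zeros(3, requires_grad=True)
        tensor = b + 0
        out = tensor * tensor  # mul saves `tensor` for its backward
        tensor[0] = 1.0  # in-place write bumps the version counter
        try:
            out.backward(torch.ones_like(out))
            raise AssertionError("expected a version-counter RuntimeError")
        except RuntimeError:
            pass
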
3068*da0073e9SAndroid Build Coastguard Worker    def test_volatile_deprecated(self):
3069*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(3, 3)
3070*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
3071*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(v.volatile)
3072*da0073e9SAndroid Build Coastguard Worker        self.assertIn("volatile", str(w[0].message))
3073*da0073e9SAndroid Build Coastguard Worker
3074*da0073e9SAndroid Build Coastguard Worker    def test_saved_variables_deprecated(self):
3075*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
3076*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3077*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, tensor1, tensor2):
3078*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(tensor1, tensor2)
3079*da0073e9SAndroid Build Coastguard Worker                return tensor1 + tensor2
3080*da0073e9SAndroid Build Coastguard Worker
3081*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3082*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3083*da0073e9SAndroid Build Coastguard Worker                var1, var2 = ctx.saved_variables
3084*da0073e9SAndroid Build Coastguard Worker                return (grad_output, grad_output)
3085*da0073e9SAndroid Build Coastguard Worker
3086*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as warns:
3087*da0073e9SAndroid Build Coastguard Worker            warnings.simplefilter("always")
3088*da0073e9SAndroid Build Coastguard Worker            x = torch.randn((3, 3), requires_grad=True)
3089*da0073e9SAndroid Build Coastguard Worker            y = torch.randn((3, 3), requires_grad=True)
3090*da0073e9SAndroid Build Coastguard Worker            MyFunction.apply(x, y).sum().backward()
3091*da0073e9SAndroid Build Coastguard Worker
3092*da0073e9SAndroid Build Coastguard Worker            has_deprecated = (
3093*da0073e9SAndroid Build Coastguard Worker                "deprecated" in str(warn) and "saved_variables" in str(warn)
3094*da0073e9SAndroid Build Coastguard Worker                for warn in warns
3095*da0073e9SAndroid Build Coastguard Worker            )
3096*da0073e9SAndroid Build Coastguard Worker            has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
3097*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(has_deprecated)
3098*da0073e9SAndroid Build Coastguard Worker
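    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the non-deprecated spelling of the pattern tested above
    # reads the tensors back via ctx.saved_tensors.
    def _sketch_saved_tensors(self):
        class MyAdd(Function):
            @staticmethod
            def forward(ctx, a, b):
                ctx.save_for_backward(a, b)
                return a + b

            @staticmethod
            def backward(ctx, grad_output):
                a, b = ctx.saved_tensors
                return grad_output, grad_output

        x = torch.randn(3, requires_grad=True)
        y = torch.randn(3, requires_grad=True)
        MyAdd.apply(x, y).sum().backward()
        assert torch.equal(x.grad, torch.ones(3))
        assert torch.equal(y.grad, torch.ones(3))
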
3099*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad(self):
3100*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
3101*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5)
3102*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(5, 5, requires_grad=True)
3103*da0073e9SAndroid Build Coastguard Worker        a = x + y
3104*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(a.requires_grad)
3105*da0073e9SAndroid Build Coastguard Worker        b = a + z
3106*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.requires_grad)
3107*da0073e9SAndroid Build Coastguard Worker
3108*da0073e9SAndroid Build Coastguard Worker        def error():
3109*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError
3110*da0073e9SAndroid Build Coastguard Worker
3111*da0073e9SAndroid Build Coastguard Worker        # Make sure backward isn't called on these
3112*da0073e9SAndroid Build Coastguard Worker        a._backward_hooks = OrderedDict()
3113*da0073e9SAndroid Build Coastguard Worker        x._backward_hooks = OrderedDict()
3114*da0073e9SAndroid Build Coastguard Worker        y._backward_hooks = OrderedDict()
3115*da0073e9SAndroid Build Coastguard Worker        a._backward_hooks["test"] = error
3116*da0073e9SAndroid Build Coastguard Worker        x._backward_hooks["test"] = error
3117*da0073e9SAndroid Build Coastguard Worker        y._backward_hooks["test"] = error
3118*da0073e9SAndroid Build Coastguard Worker        b.backward(torch.ones(5, 5))
3119*da0073e9SAndroid Build Coastguard Worker
3120*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_(self):
3121*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
3122*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=True)
3123*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x, x.requires_grad_())
3124*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.requires_grad)
3125*da0073e9SAndroid Build Coastguard Worker        self.assertIs(y, y.requires_grad_())
3126*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.requires_grad)
3127*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x, x.requires_grad_(True))
3128*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.requires_grad)
3129*da0073e9SAndroid Build Coastguard Worker        self.assertIs(y, y.requires_grad_(True))
3130*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.requires_grad)
3131*da0073e9SAndroid Build Coastguard Worker        z = x * y
3132*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: z.requires_grad_(False))
3133*da0073e9SAndroid Build Coastguard Worker        self.assertIs(z, z.requires_grad_())
3134*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.requires_grad)
3135*da0073e9SAndroid Build Coastguard Worker        self.assertIs(z, z.requires_grad_(True))
3136*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(z.requires_grad)
3137*da0073e9SAndroid Build Coastguard Worker
3138*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x, x.requires_grad_(False))
3139*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.requires_grad)
3140*da0073e9SAndroid Build Coastguard Worker        self.assertIs(y, y.requires_grad_(False))
3141*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.requires_grad)
3142*da0073e9SAndroid Build Coastguard Worker
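    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): a non-leaf cannot have requires_grad flipped off in place,
    # which is why the test above expects a RuntimeError; detach() is the
    # supported way to get a non-tracking tensor with the same values.
    def _sketch_detach_instead_of_requires_grad_false(self):
        x = torch.randn(3, requires_grad=True)
        z = x * 2  # non-leaf
        d = z.detach()
        assert not d.requires_grad
        assert z.requires_grad
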
3143*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_inplace(self):
3144*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5)
3145*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 5, requires_grad=True)
3146*da0073e9SAndroid Build Coastguard Worker        a += b
3147*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.requires_grad)
3148*da0073e9SAndroid Build Coastguard Worker
3149*da0073e9SAndroid Build Coastguard Worker        # non-leaf
3150*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5) + 0
3151*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 5, requires_grad=True)
3152*da0073e9SAndroid Build Coastguard Worker        a += b
3153*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.requires_grad)
3154*da0073e9SAndroid Build Coastguard Worker
3155*da0073e9SAndroid Build Coastguard Worker    def test_no_requires_grad_inplace(self):
3156*da0073e9SAndroid Build Coastguard Worker        # basic case, should be able to modify inplace while requires_grad is False
3157*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3)
3158*da0073e9SAndroid Build Coastguard Worker        a.add_(5)
3159*da0073e9SAndroid Build Coastguard Worker        a.requires_grad = True
3160*da0073e9SAndroid Build Coastguard Worker        a.sum().backward()
3161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.ones(2, 3))
3162*da0073e9SAndroid Build Coastguard Worker
3163*da0073e9SAndroid Build Coastguard Worker        # same but with a view
3164*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3)
3165*da0073e9SAndroid Build Coastguard Worker        b = a[:]
3166*da0073e9SAndroid Build Coastguard Worker        b.add_(5)
3167*da0073e9SAndroid Build Coastguard Worker        a.requires_grad = True
3168*da0073e9SAndroid Build Coastguard Worker        a.sum().backward()
3169*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, torch.ones(2, 3))
3170*da0073e9SAndroid Build Coastguard Worker
3171*da0073e9SAndroid Build Coastguard Worker        # should fail if we modify in place while requires_grad is True
3172*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3)
3173*da0073e9SAndroid Build Coastguard Worker        b = a[:]
3174*da0073e9SAndroid Build Coastguard Worker        a.requires_grad = True
3175*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3176*da0073e9SAndroid Build Coastguard Worker            a.add_(5)
3177*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3178*da0073e9SAndroid Build Coastguard Worker            b.add_(5)
3179*da0073e9SAndroid Build Coastguard Worker
3180*da0073e9SAndroid Build Coastguard Worker    def test_attribute_deletion(self):
3181*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((5, 5), requires_grad=True)
3182*da0073e9SAndroid Build Coastguard Worker        del x.grad
3183*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(x.grad)
3184*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3185*da0073e9SAndroid Build Coastguard Worker            del x.data
3186*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(TypeError):
3187*da0073e9SAndroid Build Coastguard Worker            x.data = None
3188*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3189*da0073e9SAndroid Build Coastguard Worker            del x.requires_grad
3190*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3191*da0073e9SAndroid Build Coastguard Worker            del x._grad_fn
3192*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3193*da0073e9SAndroid Build Coastguard Worker            del x._backward_hooks
3194*da0073e9SAndroid Build Coastguard Worker
3195*da0073e9SAndroid Build Coastguard Worker    def test_duplicate_backward_root(self):
3196*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5, requires_grad=True)
3197*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 5, requires_grad=True)
3198*da0073e9SAndroid Build Coastguard Worker
3199*da0073e9SAndroid Build Coastguard Worker        x = a * b
3200*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn_like(x)
3201*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward([x, x], [grad_output, grad_output])
3202*da0073e9SAndroid Build Coastguard Worker
3203*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, b * grad_output * 2)
3204*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, a * grad_output * 2)
3205*da0073e9SAndroid Build Coastguard Worker
3206*da0073e9SAndroid Build Coastguard Worker    def test_backward_no_grad(self):
3207*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 5, requires_grad=True)
3208*da0073e9SAndroid Build Coastguard Worker        b = a + 2
3209*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
3210*da0073e9SAndroid Build Coastguard Worker            torch.autograd.backward([b], [None])
3211*da0073e9SAndroid Build Coastguard Worker
3212*da0073e9SAndroid Build Coastguard Worker    def test_backward_twice_with_saved_values(self):
3213*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3214*da0073e9SAndroid Build Coastguard Worker        c = torch.zeros(3, dtype=torch.double)
3215*da0073e9SAndroid Build Coastguard Worker        c[[1, 2]] = b[[1, 1]]
3216*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3217*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
3218*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
3219*da0073e9SAndroid Build Coastguard Worker            "Specify retain_graph=True",
3220*da0073e9SAndroid Build Coastguard Worker            lambda: c.backward(torch.tensor([1, 1, 1], dtype=torch.double)),
3221*da0073e9SAndroid Build Coastguard Worker        )
3222*da0073e9SAndroid Build Coastguard Worker
3223*da0073e9SAndroid Build Coastguard Worker    def test_backward_twice_retained_graph_with_saved_values(self):
3224*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3225*da0073e9SAndroid Build Coastguard Worker        c = torch.zeros(3, dtype=torch.double)
3226*da0073e9SAndroid Build Coastguard Worker        c[[1, 2]] = b[[1, 1]]
3227*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
3228*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3229*da0073e9SAndroid Build Coastguard Worker
3230*da0073e9SAndroid Build Coastguard Worker    def test_backward_twice_without_saved_values(self):
3231*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3232*da0073e9SAndroid Build Coastguard Worker        c = b + 1
3233*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3234*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3235*da0073e9SAndroid Build Coastguard Worker
3236*da0073e9SAndroid Build Coastguard Worker    def test_backward_twice_retained_graph_without_saved_values(self):
3237*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(3, requires_grad=True, dtype=torch.double)
3238*da0073e9SAndroid Build Coastguard Worker        c = torch.zeros(3, dtype=torch.double)
3239*da0073e9SAndroid Build Coastguard Worker        c[[1, 2]] = b[[1, 1]]
3240*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
3241*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
3242*da0073e9SAndroid Build Coastguard Worker
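    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the retain_graph requirement exercised above applies to any
    # op that saves tensors for backward, e.g. multiplication.
    def _sketch_retain_graph(self):
        a = torch.randn(3, requires_grad=True)
        out = (a * a).sum()
        out.backward(retain_graph=True)  # keep the saved tensors alive
        out.backward()  # second call is fine; buffers are freed now
        assert torch.equal(a.grad, 4 * a)
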
3243*da0073e9SAndroid Build Coastguard Worker    def test_backward_create_graph_warns(self):
3244*da0073e9SAndroid Build Coastguard Worker        with set_warn_always_context(True):
3245*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(3, requires_grad=True, dtype=torch.double)
3246*da0073e9SAndroid Build Coastguard Worker            c = b * b
3247*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as ws:
3248*da0073e9SAndroid Build Coastguard Worker                c.backward(torch.ones_like(c), create_graph=True)
3249*da0073e9SAndroid Build Coastguard Worker            b.grad = None
3250*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
3251*da0073e9SAndroid Build Coastguard Worker                any(
3252*da0073e9SAndroid Build Coastguard Worker                    "Using backward() with create_graph=True" in str(w.message)
3253*da0073e9SAndroid Build Coastguard Worker                    for w in ws
3254*da0073e9SAndroid Build Coastguard Worker                )
3255*da0073e9SAndroid Build Coastguard Worker            )
3256*da0073e9SAndroid Build Coastguard Worker
3257*da0073e9SAndroid Build Coastguard Worker            # Should not warn for grad
3258*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as ws:
3259*da0073e9SAndroid Build Coastguard Worker                torch.autograd.grad(c, b, torch.ones_like(c), create_graph=True)
3260*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
3261*da0073e9SAndroid Build Coastguard Worker                any(
3262*da0073e9SAndroid Build Coastguard Worker                    "Using backward() with create_graph=True" in str(w.message)
3263*da0073e9SAndroid Build Coastguard Worker                    for w in ws
3264*da0073e9SAndroid Build Coastguard Worker                )
3265*da0073e9SAndroid Build Coastguard Worker            )
3266*da0073e9SAndroid Build Coastguard Worker
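    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the double-backward pattern the warning above steers users
    # toward, using torch.autograd.grad with create_graph=True.
    def _sketch_double_backward_via_grad(self):
        x = torch.randn(3, requires_grad=True)
        y = (x**3).sum()
        (gx,) = torch.autograd.grad(y, x, create_graph=True)
        (ggx,) = torch.autograd.grad(gx.sum(), x)
        # d/dx sum(x**3) = 3*x**2 and d/dx sum(3*x**2) = 6*x
        assert torch.allclose(ggx, 6 * x)
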
3267*da0073e9SAndroid Build Coastguard Worker    def test_next_functions(self):
3268*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
3269*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=True)
3270*da0073e9SAndroid Build Coastguard Worker
3271*da0073e9SAndroid Build Coastguard Worker        a = x + y
3272*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(a.grad_fn)
3273*da0073e9SAndroid Build Coastguard Worker        next_functions = a.grad_fn.next_functions
3274*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(next_functions), 2)
3275*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
3276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next_functions[0][1], 0)
3277*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
3278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(next_functions[1][1], 0)
3279*da0073e9SAndroid Build Coastguard Worker
3280*da0073e9SAndroid Build Coastguard Worker        b = a + 5
3281*da0073e9SAndroid Build Coastguard Worker        next_functions = b.grad_fn.next_functions
3282*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(next_functions), 2)
3283*da0073e9SAndroid Build Coastguard Worker        self.assertIs(next_functions[0][0], a.grad_fn)
3284*da0073e9SAndroid Build Coastguard Worker        self.assertIs(next_functions[1][0], None)
3285*da0073e9SAndroid Build Coastguard Worker
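    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): next_functions can be followed recursively to walk the
    # autograd graph back to its AccumulateGrad leaves.
    def _sketch_walk_graph_to_leaves(self):
        x = torch.randn(2, requires_grad=True)
        y = (x * 2 + 1).sum()

        def collect_leaves(fn, acc):
            if fn is None:
                return
            if isinstance(fn, torch._C._functions.AccumulateGrad):
                acc.append(fn.variable)
                return
            for next_fn, _ in fn.next_functions:
                collect_leaves(next_fn, acc)

        leaves = []
        collect_leaves(y.grad_fn, leaves)
        assert any(leaf is x for leaf in leaves)
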
3286*da0073e9SAndroid Build Coastguard Worker    def test_inplace(self):
3287*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
3288*da0073e9SAndroid Build Coastguard Worker        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)
3289*da0073e9SAndroid Build Coastguard Worker
3290*da0073e9SAndroid Build Coastguard Worker        z = x * y
3291*da0073e9SAndroid Build Coastguard Worker        q = z + y
3292*da0073e9SAndroid Build Coastguard Worker        w = z * y
3293*da0073e9SAndroid Build Coastguard Worker        z.add_(2)
3294*da0073e9SAndroid Build Coastguard Worker        # Add doesn't need its inputs to do backward, so it shouldn't raise
3295*da0073e9SAndroid Build Coastguard Worker        q.backward(torch.ones(5, 5), retain_graph=True)
3296*da0073e9SAndroid Build Coastguard Worker        # Mul saves both inputs in forward, so it should raise
3297*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
3298*da0073e9SAndroid Build Coastguard Worker
3299*da0073e9SAndroid Build Coastguard Worker        z = x * y
3300*da0073e9SAndroid Build Coastguard Worker        q = z * y
3301*da0073e9SAndroid Build Coastguard Worker        r = z + y
3302*da0073e9SAndroid Build Coastguard Worker        w = z.add_(y)
3303*da0073e9SAndroid Build Coastguard Worker        # w is the last expression, so this should succeed
3304*da0073e9SAndroid Build Coastguard Worker        w.backward(torch.ones(5, 5), retain_graph=True)
3305*da0073e9SAndroid Build Coastguard Worker        # r doesn't use the modified value in backward, so it should succeed
3306*da0073e9SAndroid Build Coastguard Worker        r.backward(torch.ones(5, 5), retain_graph=True)
3307*da0073e9SAndroid Build Coastguard Worker        # q uses dirty z, so it should raise
3308*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
3309*da0073e9SAndroid Build Coastguard Worker
3310*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3311*da0073e9SAndroid Build Coastguard Worker            x.grad.zero_()
3312*da0073e9SAndroid Build Coastguard Worker        m = x / 2
3313*da0073e9SAndroid Build Coastguard Worker        z = m + y / 8
3314*da0073e9SAndroid Build Coastguard Worker        q = z * y
3315*da0073e9SAndroid Build Coastguard Worker        r = z + y
3316*da0073e9SAndroid Build Coastguard Worker        prev_version = z._version
3317*da0073e9SAndroid Build Coastguard Worker        w = z.exp_()
3318*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(z._version, prev_version)
3319*da0073e9SAndroid Build Coastguard Worker        r.backward(torch.ones(5, 5), retain_graph=True)
3320*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(5, 5) / 2)
3321*da0073e9SAndroid Build Coastguard Worker        w.backward(torch.ones(5, 5), retain_graph=True)
3322*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.empty(5, 5).fill_((1 + math.e) / 2))
3323*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))
3324*da0073e9SAndroid Build Coastguard Worker
3325*da0073e9SAndroid Build Coastguard Worker        leaf = torch.ones(5, 5, requires_grad=True)
3326*da0073e9SAndroid Build Coastguard Worker        x = leaf.clone()
3327*da0073e9SAndroid Build Coastguard Worker        x.add_(10)
3328*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.ones(5, 5) * 11)
3329*da0073e9SAndroid Build Coastguard Worker        # x should still be usable
3330*da0073e9SAndroid Build Coastguard Worker        y = x + 2
3331*da0073e9SAndroid Build Coastguard Worker        y.backward(torch.ones(5, 5))
3332*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(leaf.grad, torch.ones(5, 5))
3333*da0073e9SAndroid Build Coastguard Worker        z = x * y
3334*da0073e9SAndroid Build Coastguard Worker        x.add_(2)
3335*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
3336*da0073e9SAndroid Build Coastguard Worker
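    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): the `_version` counter consulted above is bumped by every
    # in-place op, which is how autograd detects that a saved tensor changed.
    def _sketch_version_counter_basics(self):
        t = torch.ones(3)
        v0 = t._version
        t.add_(1)  # in-place: bumps the counter
        assert t._version > v0
        v1 = t._version
        _ = t + 1  # out-of-place: leaves the counter unchanged
        assert t._version == v1
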
3337*da0073e9SAndroid Build Coastguard Worker    def test_mark_non_differentiable(self):
3338*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
3339*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3340*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
3341*da0073e9SAndroid Build Coastguard Worker                output = input > 0
3342*da0073e9SAndroid Build Coastguard Worker                ctx.mark_non_differentiable(output)
3343*da0073e9SAndroid Build Coastguard Worker                return output
3344*da0073e9SAndroid Build Coastguard Worker
3345*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3346*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3347*da0073e9SAndroid Build Coastguard Worker                return (grad_output * 0).to(torch.double)
3348*da0073e9SAndroid Build Coastguard Worker
3349*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
3350*da0073e9SAndroid Build Coastguard Worker        mask = MyFunction.apply(x)
3351*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(mask.requires_grad)
3352*da0073e9SAndroid Build Coastguard Worker        y = x.masked_fill(mask, 0)
3353*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
3354*da0073e9SAndroid Build Coastguard Worker
3355*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
3356*da0073e9SAndroid Build Coastguard Worker    def test_mark_non_differentiable_mixed(self):
3357*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
3358*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3359*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
3360*da0073e9SAndroid Build Coastguard Worker                a = input + 1
3361*da0073e9SAndroid Build Coastguard Worker                b = input + 2
3362*da0073e9SAndroid Build Coastguard Worker                ctx.mark_non_differentiable(a)
3363*da0073e9SAndroid Build Coastguard Worker                return a, b
3364*da0073e9SAndroid Build Coastguard Worker
3365*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3366*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_a, grad_b):
3367*da0073e9SAndroid Build Coastguard Worker                self.assertTrue((grad_a == 0).all())
3368*da0073e9SAndroid Build Coastguard Worker                self.assertTrue((grad_b == 1).all())
3369*da0073e9SAndroid Build Coastguard Worker                return grad_b
3370*da0073e9SAndroid Build Coastguard Worker
3371*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
3372*da0073e9SAndroid Build Coastguard Worker        a, b = MyFunction.apply(x)
3373*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(a.requires_grad)
3374*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b.requires_grad)
3375*da0073e9SAndroid Build Coastguard Worker        b.sum().backward()
3376*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(5, 5))
3377*da0073e9SAndroid Build Coastguard Worker
3378*da0073e9SAndroid Build Coastguard Worker    def test_mark_non_differentiable_none(self):
3379*da0073e9SAndroid Build Coastguard Worker        # This used to segfault because MyFunction would send back null
3380*da0073e9SAndroid Build Coastguard Worker        # gradients to MulBackward, which is implemented in C++. Functions
3381*da0073e9SAndroid Build Coastguard Worker        # implemented in C++ expect incoming grad_outputs to be non-null.
3382*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
3383*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3384*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
3385*da0073e9SAndroid Build Coastguard Worker                output = input.clone()
3386*da0073e9SAndroid Build Coastguard Worker                ctx.mark_non_differentiable(output)
3387*da0073e9SAndroid Build Coastguard Worker                return output
3388*da0073e9SAndroid Build Coastguard Worker
3389*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3390*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3391*da0073e9SAndroid Build Coastguard Worker                return None
3392*da0073e9SAndroid Build Coastguard Worker
3393*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
3394*da0073e9SAndroid Build Coastguard Worker        r = MyFunction.apply(x * x)
3395*da0073e9SAndroid Build Coastguard Worker        (r * x).sum().backward()
3396*da0073e9SAndroid Build Coastguard Worker
3397*da0073e9SAndroid Build Coastguard Worker    def test_return_duplicate(self):
3398*da0073e9SAndroid Build Coastguard Worker        class DoubleDuplicate(Function):
3399*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3400*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3401*da0073e9SAndroid Build Coastguard Worker                output = x * 2
3402*da0073e9SAndroid Build Coastguard Worker                return output, output
3403*da0073e9SAndroid Build Coastguard Worker
3404*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3405*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad1, grad2):
3406*da0073e9SAndroid Build Coastguard Worker                return grad1 * 2 + grad2 * 2
3407*da0073e9SAndroid Build Coastguard Worker
3408*da0073e9SAndroid Build Coastguard Worker        def fn(x):
3409*da0073e9SAndroid Build Coastguard Worker            a, b = DoubleDuplicate.apply(x)
3410*da0073e9SAndroid Build Coastguard Worker            self.assertIs(a, b)
3411*da0073e9SAndroid Build Coastguard Worker            return a + b
3412*da0073e9SAndroid Build Coastguard Worker
3413*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
3414*da0073e9SAndroid Build Coastguard Worker        gradcheck(fn, [x])
3415*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(fn, [x])
3416*da0073e9SAndroid Build Coastguard Worker
3417*da0073e9SAndroid Build Coastguard Worker    def test_return_duplicate_inplace(self):
3418*da0073e9SAndroid Build Coastguard Worker        class DoubleInplace(Function):
3419*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3420*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3421*da0073e9SAndroid Build Coastguard Worker                x.mul_(2)
3422*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(x)
3423*da0073e9SAndroid Build Coastguard Worker                return x, x
3424*da0073e9SAndroid Build Coastguard Worker
3425*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3426*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad1, grad2):
3427*da0073e9SAndroid Build Coastguard Worker                return grad1 * 2 + grad2 * 2
3428*da0073e9SAndroid Build Coastguard Worker
3429*da0073e9SAndroid Build Coastguard Worker        def inplace_fn(x):
3430*da0073e9SAndroid Build Coastguard Worker            a, b = DoubleInplace.apply(x.clone())
3431*da0073e9SAndroid Build Coastguard Worker            self.assertIs(a, b)
3432*da0073e9SAndroid Build Coastguard Worker            return a + b
3433*da0073e9SAndroid Build Coastguard Worker
3434*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
3435*da0073e9SAndroid Build Coastguard Worker        gradcheck(inplace_fn, [x])
3436*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(inplace_fn, [x])
3437*da0073e9SAndroid Build Coastguard Worker
3438*da0073e9SAndroid Build Coastguard Worker        # Can't modify leaf variables in-place
3439*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x))
3440*da0073e9SAndroid Build Coastguard Worker        # Functions which modify views in-place must return only one output
3441*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x.clone()[0]))
3442*da0073e9SAndroid Build Coastguard Worker
3443*da0073e9SAndroid Build Coastguard Worker    def _test_setitem(self, size, index):
3444*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(*size, requires_grad=True)
3445*da0073e9SAndroid Build Coastguard Worker        y = x + 2
3446*da0073e9SAndroid Build Coastguard Worker        y_version = y._version
3447*da0073e9SAndroid Build Coastguard Worker        y[index] = 2
3448*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(y._version, y_version)
3449*da0073e9SAndroid Build Coastguard Worker        y.backward(torch.ones(*size))
3450*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.ones(*size)
3451*da0073e9SAndroid Build Coastguard Worker        expected_grad[index] = 0
3452*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, expected_grad)
3453*da0073e9SAndroid Build Coastguard Worker
3454*da0073e9SAndroid Build Coastguard Worker    def _test_setitem_tensor(self, size, index):
3455*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(*size, requires_grad=True)
3456*da0073e9SAndroid Build Coastguard Worker        y = x + 2
3457*da0073e9SAndroid Build Coastguard Worker        y_version = y._version
3458*da0073e9SAndroid Build Coastguard Worker        value = x.new(x[index].size()).fill_(7)
3459*da0073e9SAndroid Build Coastguard Worker        value.requires_grad = True
3460*da0073e9SAndroid Build Coastguard Worker        y[index] = value
3461*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(y._version, y_version)
3462*da0073e9SAndroid Build Coastguard Worker        y.backward(torch.ones(*size))
3463*da0073e9SAndroid Build Coastguard Worker        expected_grad_input = torch.ones(*size)
3464*da0073e9SAndroid Build Coastguard Worker        expected_grad_input[index] = 0
3465*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, expected_grad_input)
3466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(value.grad, torch.ones_like(value))
3467*da0073e9SAndroid Build Coastguard Worker
3468*da0073e9SAndroid Build Coastguard Worker        # case where x broadcasts to match the shape of y[1]
3469*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, requires_grad=True)
3470*da0073e9SAndroid Build Coastguard Worker        y = torch.zeros(2, 3, 4)
3471*da0073e9SAndroid Build Coastguard Worker        y[1] = x
3472*da0073e9SAndroid Build Coastguard Worker        y.backward(torch.randn(2, 3, 4))
3473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.size(), x.grad.size())
3474*da0073e9SAndroid Build Coastguard Worker
3475*da0073e9SAndroid Build Coastguard Worker    def test_setitem(self):
3476*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5), 1)
3477*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5,), 1)
3478*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((1,), 0)
3479*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((10,), [[0, 4, 2]])
3480*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5), [[0, 4], [2, 2]])
3481*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]])
3482*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)])
3483*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)])
3484*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]])
3485*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)])
3486*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5), 3)
3487*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]])
3488*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5,), 3)
3489*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor(
3490*da0073e9SAndroid Build Coastguard Worker            (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum()
3491*da0073e9SAndroid Build Coastguard Worker        )
3492*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
3493*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]])
3494*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)])
3495*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)])
3496*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]])
3497*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)])
3498*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor(
3499*da0073e9SAndroid Build Coastguard Worker            (5, 5, 5),
3500*da0073e9SAndroid Build Coastguard Worker            [
3501*da0073e9SAndroid Build Coastguard Worker                Variable(torch.LongTensor([1, 3]), requires_grad=False),
3502*da0073e9SAndroid Build Coastguard Worker                [2, 4],
3503*da0073e9SAndroid Build Coastguard Worker                slice(None),
3504*da0073e9SAndroid Build Coastguard Worker            ],
3505*da0073e9SAndroid Build Coastguard Worker        )
3506*da0073e9SAndroid Build Coastguard Worker
3507*da0073e9SAndroid Build Coastguard Worker    def test_setitem_mask(self):
3508*da0073e9SAndroid Build Coastguard Worker        mask = torch.BoolTensor(5, 5).bernoulli_()
3509*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5, 5), Variable(mask))
3510*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((5,), Variable(mask[0]))
3511*da0073e9SAndroid Build Coastguard Worker        self._test_setitem((1,), Variable(mask[0, 0:1]))
3512*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5, 5), Variable(mask))
3513*da0073e9SAndroid Build Coastguard Worker        self._test_setitem_tensor((5,), Variable(mask[0]))
3514*da0073e9SAndroid Build Coastguard Worker
3515*da0073e9SAndroid Build Coastguard Worker    def test_select_sum(self):
3516*da0073e9SAndroid Build Coastguard Worker        # both select and sum return Scalars in ATen; ensure they work together.
3517*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, dtype=torch.double, requires_grad=True)
3518*da0073e9SAndroid Build Coastguard Worker
3519*da0073e9SAndroid Build Coastguard Worker        def func(x):
3520*da0073e9SAndroid Build Coastguard Worker            return x.select(0, 1).sum()
3521*da0073e9SAndroid Build Coastguard Worker
3522*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [x])
3523*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, [x])
3524*da0073e9SAndroid Build Coastguard Worker
3525*da0073e9SAndroid Build Coastguard Worker    def test_diagonal_expanded_v(self):
3526*da0073e9SAndroid Build Coastguard Worker        value = torch.rand([])
3527*da0073e9SAndroid Build Coastguard Worker        v_expanded = torch.tensor(value).expand(10)
3528*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10, 10, dtype=torch.double, requires_grad=True)
3529*da0073e9SAndroid Build Coastguard Worker        (result,) = torch.autograd.grad(a.diagonal(), a, v_expanded)
3530*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, torch.eye(10, dtype=torch.double) * value)
3531*da0073e9SAndroid Build Coastguard Worker
3532*da0073e9SAndroid Build Coastguard Worker    def test_select_expanded_v(self):
3533*da0073e9SAndroid Build Coastguard Worker        v_expanded = torch.rand(10).expand(10, 10)
3534*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10, 10, 10, requires_grad=True)
3535*da0073e9SAndroid Build Coastguard Worker        (result,) = torch.autograd.grad(a[0], a, v_expanded)
3536*da0073e9SAndroid Build Coastguard Worker        expected = torch.zeros(10, 10, 10)
3537*da0073e9SAndroid Build Coastguard Worker        expected[0] = v_expanded
3538*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
3539*da0073e9SAndroid Build Coastguard Worker
3540*da0073e9SAndroid Build Coastguard Worker    def test_slice_expanded_v(self):
3541*da0073e9SAndroid Build Coastguard Worker        v_expanded = torch.rand(10, 1).expand(2, 10, 10)
3542*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10, 10, 10, requires_grad=True)
3543*da0073e9SAndroid Build Coastguard Worker        (result,) = torch.autograd.grad(a[3:5], a, v_expanded)
3544*da0073e9SAndroid Build Coastguard Worker        expected = torch.zeros(10, 10, 10)
3545*da0073e9SAndroid Build Coastguard Worker        expected[3:5] = v_expanded
3546*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
3547*da0073e9SAndroid Build Coastguard Worker
3548*da0073e9SAndroid Build Coastguard Worker    def test_unused_output(self):
3549*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10, requires_grad=True)
3550*da0073e9SAndroid Build Coastguard Worker        outputs = x.chunk(5)
3551*da0073e9SAndroid Build Coastguard Worker        o = outputs[2]
3552*da0073e9SAndroid Build Coastguard Worker        o = o * 4 + 2
3553*da0073e9SAndroid Build Coastguard Worker        o.sum().backward()
3554*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.zeros(10, 10)
3555*da0073e9SAndroid Build Coastguard Worker        expected_grad[4:6] = 4
3556*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, expected_grad)
3557*da0073e9SAndroid Build Coastguard Worker
3558*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3559*da0073e9SAndroid Build Coastguard Worker            x.grad.zero_()
3560*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.randn(2, 10)
3561*da0073e9SAndroid Build Coastguard Worker        outputs = x.chunk(5)
3562*da0073e9SAndroid Build Coastguard Worker        outputs[0].backward(grad_output)
3563*da0073e9SAndroid Build Coastguard Worker        expected_grad = torch.zeros(10, 10)
3564*da0073e9SAndroid Build Coastguard Worker        expected_grad[:2] = grad_output
3565*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, expected_grad)
3566*da0073e9SAndroid Build Coastguard Worker
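    # Illustrative sketch (not part of the original suite; the helper name is
    # hypothetical): outputs that never reach the loss contribute nothing to the
    # leaf gradient, which is what the chunk test above relies on.
    def _sketch_unused_output_zero_grad(self):
        x = torch.randn(4, requires_grad=True)
        used, unused = x.chunk(2)
        used.sum().backward()
        assert torch.equal(x.grad[:2], torch.ones(2))
        assert torch.equal(x.grad[2:], torch.zeros(2))
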
3567*da0073e9SAndroid Build Coastguard Worker    # TODO: opinfo this or move to the sparse test suite
3568*da0073e9SAndroid Build Coastguard Worker    def _test_sparse_gather(self, size_x, size_ind, dim):
3569*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(size_x, requires_grad=True)
3570*da0073e9SAndroid Build Coastguard Worker        if len(size_ind) > 0 and len(size_x) > 0:
3571*da0073e9SAndroid Build Coastguard Worker            ind = torch.randint(x.size(dim), size_ind)
3572*da0073e9SAndroid Build Coastguard Worker        else:
3573*da0073e9SAndroid Build Coastguard Worker            ind = torch.zeros(size_ind, dtype=torch.int64)
3574*da0073e9SAndroid Build Coastguard Worker        out = torch.gather(x, dim, ind, sparse_grad=False)
3575*da0073e9SAndroid Build Coastguard Worker        grad = torch.rand_like(out)
3576*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
3577*da0073e9SAndroid Build Coastguard Worker        grad_dense = x.grad.clone()
3578*da0073e9SAndroid Build Coastguard Worker        x.grad = None
3579*da0073e9SAndroid Build Coastguard Worker        out = torch.gather(x, dim, ind, sparse_grad=True)
3580*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
3581*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_dense, x.grad.to_dense())
3582*da0073e9SAndroid Build Coastguard Worker
3583*da0073e9SAndroid Build Coastguard Worker    def test_sparse_gather_dim0(self):
3584*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_gather((10, 10), (5, 10), 0)
3585*da0073e9SAndroid Build Coastguard Worker
3586*da0073e9SAndroid Build Coastguard Worker    def test_sparse_gather_dim1(self):
3587*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)
3588*da0073e9SAndroid Build Coastguard Worker
3589*da0073e9SAndroid Build Coastguard Worker    def test_sparse_gather_dim_neg(self):
3590*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)
3591*da0073e9SAndroid Build Coastguard Worker
3592*da0073e9SAndroid Build Coastguard Worker    def test_sparse_gather_ind_scalar(self):
3593*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_gather((10,), (), 0)
3594*da0073e9SAndroid Build Coastguard Worker
3595*da0073e9SAndroid Build Coastguard Worker    def test_sparse_gather_x_scalar(self):
3596*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_gather((), (2,), 0)
3597*da0073e9SAndroid Build Coastguard Worker
3598*da0073e9SAndroid Build Coastguard Worker    def test_sparse_gather_both_scalar(self):
3599*da0073e9SAndroid Build Coastguard Worker        self._test_sparse_gather((), (), 0)
3600*da0073e9SAndroid Build Coastguard Worker
3601*da0073e9SAndroid Build Coastguard Worker    def test_gc_in_destructor(self):
3602*da0073e9SAndroid Build Coastguard Worker        """
3603*da0073e9SAndroid Build Coastguard Worker        Previously, if a Function destructor triggered a garbage collection,
3604*da0073e9SAndroid Build Coastguard Worker        the Variable's tp_dealloc handler would get called twice, leading to a
3605*da0073e9SAndroid Build Coastguard Worker        segfault.
3606*da0073e9SAndroid Build Coastguard Worker        """
3607*da0073e9SAndroid Build Coastguard Worker
3608*da0073e9SAndroid Build Coastguard Worker        class CollectOnDelete(Function):
3609*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
3610*da0073e9SAndroid Build Coastguard Worker                return x
3611*da0073e9SAndroid Build Coastguard Worker
3612*da0073e9SAndroid Build Coastguard Worker            def backward(self, grad_output):
3613*da0073e9SAndroid Build Coastguard Worker                return grad_output
3614*da0073e9SAndroid Build Coastguard Worker
3615*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
3616*da0073e9SAndroid Build Coastguard Worker                gc.collect()
3617*da0073e9SAndroid Build Coastguard Worker
3618*da0073e9SAndroid Build Coastguard Worker        for _ in range(10):
3619*da0073e9SAndroid Build Coastguard Worker            CollectOnDelete().forward(torch.randn(1, requires_grad=True)).backward()
3620*da0073e9SAndroid Build Coastguard Worker
3621*da0073e9SAndroid Build Coastguard Worker    def test_naughty_autograd_function_attribute_access(self):
3622*da0073e9SAndroid Build Coastguard Worker        class Id(Function):
3623*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3624*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3625*da0073e9SAndroid Build Coastguard Worker                return x
3626*da0073e9SAndroid Build Coastguard Worker
3627*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3628*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_x):
3629*da0073e9SAndroid Build Coastguard Worker                return grad_x
3630*da0073e9SAndroid Build Coastguard Worker
3631*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(DeprecationWarning, "should not be instantiated"):
3632*da0073e9SAndroid Build Coastguard Worker            f = Id()
3633*da0073e9SAndroid Build Coastguard Worker
3634*da0073e9SAndroid Build Coastguard Worker        # After raising the warning, it should still return an instance
3635*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(f, Id)
3636*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(1, requires_grad=True)
3637*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3638*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "non-static forward method is deprecated"
3639*da0073e9SAndroid Build Coastguard Worker        ):
3640*da0073e9SAndroid Build Coastguard Worker            f(x)
3641*da0073e9SAndroid Build Coastguard Worker        t = Id.apply(x)
3642*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.grad_fn.name(), "IdBackward")
3643*da0073e9SAndroid Build Coastguard Worker
3644*da0073e9SAndroid Build Coastguard Worker        # THPFunction is the base class of both grad_fn and autograd functions,
3645*da0073e9SAndroid Build Coastguard Worker        # which means that many accessors on them may segfault. Test that we
3646*da0073e9SAndroid Build Coastguard Worker        # raise a proper error in this case.
3647*da0073e9SAndroid Build Coastguard Worker        t = torch.ones(1, requires_grad=True)
3648*da0073e9SAndroid Build Coastguard Worker        t._backward_hooks = {}
3649*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3650*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Attribute '_register_hook_dict' is invalid"
3651*da0073e9SAndroid Build Coastguard Worker        ):
3652*da0073e9SAndroid Build Coastguard Worker            f._register_hook_dict(t)
3653*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3654*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Attribute 'register_hook' is invalid"
3655*da0073e9SAndroid Build Coastguard Worker        ):
3656*da0073e9SAndroid Build Coastguard Worker            f.register_hook(lambda x, y: None)
3657*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3658*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Attribute 'next_functions' is invalid"
3659*da0073e9SAndroid Build Coastguard Worker        ):
3660*da0073e9SAndroid Build Coastguard Worker            f.next_functions
3661*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"):
3662*da0073e9SAndroid Build Coastguard Worker            f.name()
3663*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
3664*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "underlying PyNode has already been deallocated"
3665*da0073e9SAndroid Build Coastguard Worker        ):
3666*da0073e9SAndroid Build Coastguard Worker            f.metadata
3667*da0073e9SAndroid Build Coastguard Worker
3668*da0073e9SAndroid Build Coastguard Worker    @unittest.expectedFailure
3669*da0073e9SAndroid Build Coastguard Worker    def test_naughty_anomaly_access(self):
3670*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
3671*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3672*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3673*da0073e9SAndroid Build Coastguard Worker                return x
3674*da0073e9SAndroid Build Coastguard Worker
3675*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3676*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, g):
3677*da0073e9SAndroid Build Coastguard Worker                return g
3678*da0073e9SAndroid Build Coastguard Worker
3679*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(1, requires_grad=True)
3680*da0073e9SAndroid Build Coastguard Worker        y = MyFunction.apply(x)
3681*da0073e9SAndroid Build Coastguard Worker        y.backward()
3682*da0073e9SAndroid Build Coastguard Worker        y.grad_fn.metadata
3683*da0073e9SAndroid Build Coastguard Worker        g = y.grad_fn
3684*da0073e9SAndroid Build Coastguard Worker        del y
3685*da0073e9SAndroid Build Coastguard Worker        g.metadata  # this currently fails, but shouldn't
3686*da0073e9SAndroid Build Coastguard Worker
3687*da0073e9SAndroid Build Coastguard Worker    def test_naughty_autograd_function_stashing_ctx(self):
3688*da0073e9SAndroid Build Coastguard Worker        saved_ctx = []
3689*da0073e9SAndroid Build Coastguard Worker
3690*da0073e9SAndroid Build Coastguard Worker        class Id(Function):
3691*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3692*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3693*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
3694*da0073e9SAndroid Build Coastguard Worker                return x
3695*da0073e9SAndroid Build Coastguard Worker
3696*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3697*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_x):
3698*da0073e9SAndroid Build Coastguard Worker                saved_ctx.append(ctx)
3699*da0073e9SAndroid Build Coastguard Worker                return ctx.saved_tensors
3700*da0073e9SAndroid Build Coastguard Worker
3701*da0073e9SAndroid Build Coastguard Worker        p = torch.zeros(1, requires_grad=True)
3702*da0073e9SAndroid Build Coastguard Worker        loss = Id.apply(p)
3703*da0073e9SAndroid Build Coastguard Worker        loss.backward(retain_graph=True)
3704*da0073e9SAndroid Build Coastguard Worker        del loss
3705*da0073e9SAndroid Build Coastguard Worker        # At this point, it complains that the graph has been freed
3706*da0073e9SAndroid Build Coastguard Worker        # (which is indeed true, although that is a somewhat indirect way of
3707*da0073e9SAndroid Build Coastguard Worker        # stating the problem).
3708*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
3709*da0073e9SAndroid Build Coastguard Worker
3710*da0073e9SAndroid Build Coastguard Worker    def test_custom_autograd_repeated_grad_grad(self):
3711*da0073e9SAndroid Build Coastguard Worker        # This test failed the equality check in PR #22983; it's an interesting
3712*da0073e9SAndroid Build Coastguard Worker        # and different test case worth enshrining.  mult1 is not testing
3713*da0073e9SAndroid Build Coastguard Worker        # anything particularly interesting, but mult2 is the interesting case.
3714*da0073e9SAndroid Build Coastguard Worker
3715*da0073e9SAndroid Build Coastguard Worker        def mult1(x):
3716*da0073e9SAndroid Build Coastguard Worker            return x.prod(dim=-1).prod(dim=-1)
3717*da0073e9SAndroid Build Coastguard Worker
3718*da0073e9SAndroid Build Coastguard Worker        class Mult(torch.autograd.Function):
3719*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3720*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3721*da0073e9SAndroid Build Coastguard Worker                y = mult1(x)
3722*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x, y)
3723*da0073e9SAndroid Build Coastguard Worker                return y
3724*da0073e9SAndroid Build Coastguard Worker
3725*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3726*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3727*da0073e9SAndroid Build Coastguard Worker                x, y = ctx.saved_tensors
3728*da0073e9SAndroid Build Coastguard Worker                return (grad_output * y)[:, None, None] / x
3729*da0073e9SAndroid Build Coastguard Worker
3730*da0073e9SAndroid Build Coastguard Worker        mult2 = Mult.apply
3731*da0073e9SAndroid Build Coastguard Worker
3732*da0073e9SAndroid Build Coastguard Worker        def check_gradgrad_repeated(x, y):
3733*da0073e9SAndroid Build Coastguard Worker            (gy,) = torch.autograd.grad(y[0], x, create_graph=True)
3734*da0073e9SAndroid Build Coastguard Worker            (ggy_1,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
3735*da0073e9SAndroid Build Coastguard Worker            (gy,) = torch.autograd.grad(y[0], x, create_graph=True)
3736*da0073e9SAndroid Build Coastguard Worker            (ggy_2,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
3737*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])
3738*da0073e9SAndroid Build Coastguard Worker
3739*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 4, 4).requires_grad_()
3740*da0073e9SAndroid Build Coastguard Worker        check_gradgrad_repeated(x, mult1(x))
3741*da0073e9SAndroid Build Coastguard Worker        check_gradgrad_repeated(x, mult2(x))
3742*da0073e9SAndroid Build Coastguard Worker
3743*da0073e9SAndroid Build Coastguard Worker    def test_custom_autograd_no_early_free(self):
3744*da0073e9SAndroid Build Coastguard Worker        # Prior to #22983, this test failed, complaining that buffers had
3745*da0073e9SAndroid Build Coastguard Worker        # already been freed.  It is also a pretty interesting test case.
3746*da0073e9SAndroid Build Coastguard Worker        class Double(torch.autograd.Function):
3747*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3748*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
3749*da0073e9SAndroid Build Coastguard Worker                y = x**2
3750*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x, y)
3751*da0073e9SAndroid Build Coastguard Worker                return y
3752*da0073e9SAndroid Build Coastguard Worker
3753*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3754*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3755*da0073e9SAndroid Build Coastguard Worker                x, _ = ctx.saved_tensors
3756*da0073e9SAndroid Build Coastguard Worker                return grad_output * 2 * x
3757*da0073e9SAndroid Build Coastguard Worker
3758*da0073e9SAndroid Build Coastguard Worker        # this is equivalent, but uses the output of .forward() in .backward()
3759*da0073e9SAndroid Build Coastguard Worker        class Double2(Double):
3760*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3761*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
3762*da0073e9SAndroid Build Coastguard Worker                x, y = ctx.saved_tensors
3763*da0073e9SAndroid Build Coastguard Worker                return grad_output * 2 * y / x
3764*da0073e9SAndroid Build Coastguard Worker
3765*da0073e9SAndroid Build Coastguard Worker        double = Double.apply
3766*da0073e9SAndroid Build Coastguard Worker        double2 = Double2.apply
3767*da0073e9SAndroid Build Coastguard Worker
3768*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2).double().requires_grad_()
3769*da0073e9SAndroid Build Coastguard Worker
3770*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(double, x))
3771*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradgradcheck(double, x))
3772*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(double2, x))
3773*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradgradcheck(double2, x))
3774*da0073e9SAndroid Build Coastguard Worker
3775*da0073e9SAndroid Build Coastguard Worker        y = double(x)
3776*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(y, x, create_graph=True)
3777*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(y, x)
3778*da0073e9SAndroid Build Coastguard Worker
3779*da0073e9SAndroid Build Coastguard Worker        y = double2(x)
3780*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(y, x, create_graph=True)
3781*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(y, x)  # should not error!
3782*da0073e9SAndroid Build Coastguard Worker
3783*da0073e9SAndroid Build Coastguard Worker    def test_detach(self):
3784*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10, requires_grad=True)
3785*da0073e9SAndroid Build Coastguard Worker        y = x + 2
3786*da0073e9SAndroid Build Coastguard Worker        y = y.detach()
3787*da0073e9SAndroid Build Coastguard Worker        z = y * 4 + 2
3788*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.requires_grad)
3789*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(z.requires_grad)
3790*da0073e9SAndroid Build Coastguard Worker
3791*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10, requires_grad=True)
3792*da0073e9SAndroid Build Coastguard Worker        y = x * 2
3793*da0073e9SAndroid Build Coastguard Worker        y = y.detach()
3794*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.requires_grad)
3795*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(y.grad_fn)
3796*da0073e9SAndroid Build Coastguard Worker        z = x + y
3797*da0073e9SAndroid Build Coastguard Worker        z.sum().backward()
3798*da0073e9SAndroid Build Coastguard Worker        # This is an incorrect gradient, but we assume that's what the user
3799*da0073e9SAndroid Build Coastguard Worker        # wanted. detach() is an advanced option.
3800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(10, 10))
3801*da0073e9SAndroid Build Coastguard Worker
3802*da0073e9SAndroid Build Coastguard Worker        # in-place detach
3803*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10, requires_grad=True)
3804*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10, 10, requires_grad=True)
3805*da0073e9SAndroid Build Coastguard Worker        a = x * 2
3806*da0073e9SAndroid Build Coastguard Worker        (y + a).sum().backward(retain_graph=True)
3807*da0073e9SAndroid Build Coastguard Worker        a.detach_()
3808*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(a.requires_grad)
3809*da0073e9SAndroid Build Coastguard Worker        (y + a).sum().backward()  # this won't backprop to x
3810*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(10, 10) * 2)
3811*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, torch.ones(10, 10) * 2)
3812*da0073e9SAndroid Build Coastguard Worker
3813*da0073e9SAndroid Build Coastguard Worker        # in-place detach on a view raises an exception
3814*da0073e9SAndroid Build Coastguard Worker        view = x.narrow(0, 1, 4)
3815*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, "view", lambda: view.detach_())
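
    # A minimal illustrative sketch (not collected by the test runner, hence no
    # "test_" prefix): it contrasts Tensor.detach(), which returns a new tensor
    # that shares storage but is cut out of the autograd graph, with
    # Tensor.detach_(), which detaches the tensor in place. All names are local
    # to this sketch.
    def _sketch_detach_vs_detach_(self):
        x = torch.randn(3, requires_grad=True)
        a = x * 2
        d = a.detach()  # new tensor, shares storage with `a`, requires_grad=False
        assert not d.requires_grad and d.data_ptr() == a.data_ptr()
        a.detach_()  # detaches `a` itself; it no longer has a grad_fn
        assert not a.requires_grad and a.grad_fn is None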
3816*da0073e9SAndroid Build Coastguard Worker
3817*da0073e9SAndroid Build Coastguard Worker    def test_detach_base(self):
3818*da0073e9SAndroid Build Coastguard Worker        "Detaching the base does not detach the view."
3819*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10, requires_grad=True)
3820*da0073e9SAndroid Build Coastguard Worker        view = x.narrow(0, 1, 4)
3821*da0073e9SAndroid Build Coastguard Worker        x.detach_()
3822*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(x.requires_grad)
3823*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(view.requires_grad)
3824*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(view.grad_fn)
3825*da0073e9SAndroid Build Coastguard Worker        self.assertIs(view._base, x)
3826*da0073e9SAndroid Build Coastguard Worker
3827*da0073e9SAndroid Build Coastguard Worker    def test_detach_then_inplace_raises_in_autograd(self):
3828*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([], requires_grad=True)
3829*da0073e9SAndroid Build Coastguard Worker        orig_x = x.detach().clone()
3830*da0073e9SAndroid Build Coastguard Worker
3831*da0073e9SAndroid Build Coastguard Worker        y = x**2  # saves x
3832*da0073e9SAndroid Build Coastguard Worker        z = x.detach()
3833*da0073e9SAndroid Build Coastguard Worker        z.zero_()
3834*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "has been modified by an inplace"):
3835*da0073e9SAndroid Build Coastguard Worker            y.backward()
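
    # A minimal illustrative sketch (not a test): because detach() returns an
    # alias that shares storage with the original tensor, an in-place write
    # through that alias bumps the shared version counter, which is what lets
    # the test above detect the stale saved value during backward.
    def _sketch_detach_alias_version_counter(self):
        x = torch.randn((), requires_grad=True)
        y = x**2  # pow saves x for its backward
        alias = x.detach()
        v_before = x._version
        alias.zero_()  # mutates the storage shared with x
        assert x._version == v_before + 1  # version counter was bumped
        # y.backward() would now raise the "modified by an inplace" error.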
3836*da0073e9SAndroid Build Coastguard Worker
3837*da0073e9SAndroid Build Coastguard Worker    def _test_type_conversion_backward(self, t):
3838*da0073e9SAndroid Build Coastguard Worker        fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True)
3839*da0073e9SAndroid Build Coastguard Worker        fvar.double().sum().backward()
3840*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fvar.grad, torch.ones_like(fvar))
3841*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(fvar.grad), type(fvar))
3842*da0073e9SAndroid Build Coastguard Worker        dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True)
3843*da0073e9SAndroid Build Coastguard Worker        dvar.float().sum().backward()
3844*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dvar.grad, torch.ones_like(dvar))
3845*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(type(dvar.grad), type(dvar))
3846*da0073e9SAndroid Build Coastguard Worker
3847*da0073e9SAndroid Build Coastguard Worker    def test_type_conversions(self):
3848*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
3849*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.float(), torch.FloatTensor)
3850*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(x.int(), torch.IntTensor)
3851*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
3852*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(x.float().cuda(), torch.cuda.FloatTensor)
3853*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(x.int().cuda(), torch.cuda.IntTensor)
3854*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(x.int().cuda().cpu(), torch.IntTensor)
3855*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() >= 2:
3856*da0073e9SAndroid Build Coastguard Worker                x2 = x.float().cuda(1)
3857*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(x2, torch.cuda.FloatTensor)
3858*da0073e9SAndroid Build Coastguard Worker                self.assertIs(x2.get_device(), 1)
3859*da0073e9SAndroid Build Coastguard Worker                x2 = x.float().cuda()
3860*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(x2, torch.cuda.FloatTensor)
3861*da0073e9SAndroid Build Coastguard Worker                self.assertIs(x2.get_device(), 0)
3862*da0073e9SAndroid Build Coastguard Worker                x2 = x2.cuda(1)
3863*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(x2, torch.cuda.FloatTensor)
3864*da0073e9SAndroid Build Coastguard Worker                self.assertIs(x2.get_device(), 1)
3865*da0073e9SAndroid Build Coastguard Worker                y = Variable(torch.randn(5).cuda(1), requires_grad=True)
3866*da0073e9SAndroid Build Coastguard Worker                y.cpu().sum().backward()
3867*da0073e9SAndroid Build Coastguard Worker                self.assertIs(y.grad.get_device(), 1)
3868*da0073e9SAndroid Build Coastguard Worker                self.assertIs(y.long().get_device(), 1)
3869*da0073e9SAndroid Build Coastguard Worker
3870*da0073e9SAndroid Build Coastguard Worker        for t in [
3871*da0073e9SAndroid Build Coastguard Worker            torch.DoubleTensor,
3872*da0073e9SAndroid Build Coastguard Worker            torch.FloatTensor,
3873*da0073e9SAndroid Build Coastguard Worker            torch.IntTensor,
3874*da0073e9SAndroid Build Coastguard Worker            torch.ByteTensor,
3875*da0073e9SAndroid Build Coastguard Worker        ]:
3876*da0073e9SAndroid Build Coastguard Worker            for y_var in (True, False):
3877*da0073e9SAndroid Build Coastguard Worker                y = torch.randint(5, (5, 5), dtype=t.dtype)
3878*da0073e9SAndroid Build Coastguard Worker                y = Variable(y) if y_var else y
3879*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(x.type(t), t)
3880*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(x.type_as(y), t)
3881*da0073e9SAndroid Build Coastguard Worker                # TODO: t.dtype should work
3882*da0073e9SAndroid Build Coastguard Worker                t_dtype = t().dtype
3883*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(x.type(t_dtype), t)
3884*da0073e9SAndroid Build Coastguard Worker                self.assertIs(t_dtype, x.type(t_dtype).dtype)
3885*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(y.data_ptr(), y.type(t).data_ptr())
3886*da0073e9SAndroid Build Coastguard Worker                if torch.cuda.is_available():
3887*da0073e9SAndroid Build Coastguard Worker                    for x_cuda in (True, False):
3888*da0073e9SAndroid Build Coastguard Worker                        for y_cuda in (True, False):
3889*da0073e9SAndroid Build Coastguard Worker                            x_c = x.cuda() if x_cuda else x
3890*da0073e9SAndroid Build Coastguard Worker                            y_c = y.cuda() if y_cuda else y
3891*da0073e9SAndroid Build Coastguard Worker                            _, y_type = y_c.type().rsplit(".", 1)
3892*da0073e9SAndroid Build Coastguard Worker                            y_typestr = ("torch.cuda." if y_cuda else "torch.") + y_type
3893*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(y_c.type(), x_c.type(y_typestr).type())
3894*da0073e9SAndroid Build Coastguard Worker                            self.assertIs(y_c.dtype, x_c.type(y_c.dtype).dtype)
3895*da0073e9SAndroid Build Coastguard Worker                            self.assertEqual(
3896*da0073e9SAndroid Build Coastguard Worker                                y_c.data_ptr(),
3897*da0073e9SAndroid Build Coastguard Worker                                y_c.cuda().data_ptr() if y_cuda else y_c.data_ptr(),
3898*da0073e9SAndroid Build Coastguard Worker                            )
3899*da0073e9SAndroid Build Coastguard Worker
3900*da0073e9SAndroid Build Coastguard Worker        self._test_type_conversion_backward(lambda x: x)
3901*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
3902*da0073e9SAndroid Build Coastguard Worker            self._test_type_conversion_backward(lambda x: x.cuda())
3903*da0073e9SAndroid Build Coastguard Worker            if torch.cuda.device_count() >= 2:
3904*da0073e9SAndroid Build Coastguard Worker                # one of these has to be the non-default device
3905*da0073e9SAndroid Build Coastguard Worker                self._test_type_conversion_backward(lambda x: x.cuda(0))
3906*da0073e9SAndroid Build Coastguard Worker                self._test_type_conversion_backward(lambda x: x.cuda(1))
3907*da0073e9SAndroid Build Coastguard Worker
3908*da0073e9SAndroid Build Coastguard Worker    def test_isolated_node(self):
3909*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
3910*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=True)
3911*da0073e9SAndroid Build Coastguard Worker
3912*da0073e9SAndroid Build Coastguard Worker        a = x + y
3913*da0073e9SAndroid Build Coastguard Worker        b = torch.max(a, 1, True)[1].repeat(1, 5).double()
3914*da0073e9SAndroid Build Coastguard Worker        o = (b + a).sum()
3915*da0073e9SAndroid Build Coastguard Worker        o.backward()
3916*da0073e9SAndroid Build Coastguard Worker
3917*da0073e9SAndroid Build Coastguard Worker    def test_shape(self):
3918*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 4)
3919*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, len(x.shape))
3920*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.shape[0], 3)
3921*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.shape[1], 4)
3922*da0073e9SAndroid Build Coastguard Worker
3923*da0073e9SAndroid Build Coastguard Worker    def test_numpy_requires_grad(self):
3924*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
3925*da0073e9SAndroid Build Coastguard Worker        err_msg_outputs = r"Can't call numpy\(\) on Tensor that requires grad. Use tensor.detach\(\).numpy\(\) instead."
3926*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg_outputs):
3927*da0073e9SAndroid Build Coastguard Worker            x.numpy()
3928*da0073e9SAndroid Build Coastguard Worker
3929*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3930*da0073e9SAndroid Build Coastguard Worker            x.numpy()
3931*da0073e9SAndroid Build Coastguard Worker
3932*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2)
3933*da0073e9SAndroid Build Coastguard Worker        x.numpy()
3934*da0073e9SAndroid Build Coastguard Worker
3935*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
3936*da0073e9SAndroid Build Coastguard Worker            x.numpy()
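
    # A minimal illustrative sketch (not a test): the error message checked
    # above recommends detaching before converting, which yields a NumPy array
    # that shares memory with the tensor but lives outside of autograd.
    def _sketch_numpy_via_detach(self):
        x = torch.randn(2, 2, requires_grad=True)
        arr = x.detach().numpy()  # no RuntimeError; shares memory with x
        assert arr.shape == (2, 2)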
3937*da0073e9SAndroid Build Coastguard Worker
3938*da0073e9SAndroid Build Coastguard Worker    def test_return_leaf(self):
3939*da0073e9SAndroid Build Coastguard Worker        class Identity(Function):
3940*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3941*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
3942*da0073e9SAndroid Build Coastguard Worker                return a, a + b
3943*da0073e9SAndroid Build Coastguard Worker
3944*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3945*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_a, grad_b):
3946*da0073e9SAndroid Build Coastguard Worker                return grad_a + grad_b, grad_b
3947*da0073e9SAndroid Build Coastguard Worker
3948*da0073e9SAndroid Build Coastguard Worker        hook_called = [False]
3949*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
3950*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=True)
3951*da0073e9SAndroid Build Coastguard Worker
3952*da0073e9SAndroid Build Coastguard Worker        q, p = Identity.apply(x, y)
3953*da0073e9SAndroid Build Coastguard Worker
3954*da0073e9SAndroid Build Coastguard Worker        # Make sure hooks only receive grad from usage of q, not x.
3955*da0073e9SAndroid Build Coastguard Worker        def hook(grad):
3956*da0073e9SAndroid Build Coastguard Worker            hook_called[0] = True
3957*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grad, torch.ones(5, 5))
3958*da0073e9SAndroid Build Coastguard Worker
3959*da0073e9SAndroid Build Coastguard Worker        q.register_hook(hook)
3960*da0073e9SAndroid Build Coastguard Worker        (q + p + x).sum().backward()
3961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(5, 5) * 3)
3962*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, torch.ones(5, 5))
3963*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hook_called[0])
3964*da0073e9SAndroid Build Coastguard Worker
3965*da0073e9SAndroid Build Coastguard Worker    def test_return_leaf_inplace(self):
3966*da0073e9SAndroid Build Coastguard Worker        class Inplace(InplaceFunction):
3967*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3968*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
3969*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(a)
3970*da0073e9SAndroid Build Coastguard Worker                return a.add_(b), b + 2
3971*da0073e9SAndroid Build Coastguard Worker
3972*da0073e9SAndroid Build Coastguard Worker            @staticmethod
3973*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_a, grad_b):
3974*da0073e9SAndroid Build Coastguard Worker                return grad_a, grad_a + grad_b
3975*da0073e9SAndroid Build Coastguard Worker
3976*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
3977*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, requires_grad=True)
3978*da0073e9SAndroid Build Coastguard Worker
3979*da0073e9SAndroid Build Coastguard Worker        q, p = Inplace.apply(x, y)
3980*da0073e9SAndroid Build Coastguard Worker        self.assertIs(q, x)
3981*da0073e9SAndroid Build Coastguard Worker        self.assertIs(q.grad_fn.__class__, Inplace._backward_cls)
3982*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(q.requires_grad)
3983*da0073e9SAndroid Build Coastguard Worker        q.sum().backward()
3984*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, torch.ones(5, 5))
3985*da0073e9SAndroid Build Coastguard Worker
3986*da0073e9SAndroid Build Coastguard Worker    def test_leaf_assignment(self):
3987*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5)
3988*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, requires_grad=True)
3989*da0073e9SAndroid Build Coastguard Worker        z = torch.randn(5, requires_grad=True)
3990*da0073e9SAndroid Build Coastguard Worker
3991*da0073e9SAndroid Build Coastguard Worker        x[0] = y
3992*da0073e9SAndroid Build Coastguard Worker        x[1] = 2 * z
3993*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.requires_grad)
3994*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(x.grad_fn, None)
3995*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
3996*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, torch.ones(5))
3997*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.grad, torch.ones(5) * 2)
3998*da0073e9SAndroid Build Coastguard Worker
3999*da0073e9SAndroid Build Coastguard Worker    def test_no_grad_assignment(self):
4000*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
4001*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5)
4002*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
4003*da0073e9SAndroid Build Coastguard Worker            x[0] = y
4004*da0073e9SAndroid Build Coastguard Worker
4005*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.requires_grad)
4006*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(x.grad_fn)
4007*da0073e9SAndroid Build Coastguard Worker
4008*da0073e9SAndroid Build Coastguard Worker    def test_no_grad_modifies_version(self):
4009*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, requires_grad=True)
4010*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, requires_grad=True)
4011*da0073e9SAndroid Build Coastguard Worker        z = (x * y).sum()
4012*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
4013*da0073e9SAndroid Build Coastguard Worker            x *= 2
4014*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(
4015*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "modified by an inplace operation", lambda: z.backward()
4016*da0073e9SAndroid Build Coastguard Worker        )
4017*da0073e9SAndroid Build Coastguard Worker
4018*da0073e9SAndroid Build Coastguard Worker    def test_increment_version(self):
4019*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(5, requires_grad=True)
4020*da0073e9SAndroid Build Coastguard Worker        v = a._version
4021*da0073e9SAndroid Build Coastguard Worker        torch.autograd.graph.increment_version(a)
4022*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a._version, v + 1)
4023*da0073e9SAndroid Build Coastguard Worker
4024*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(5, dtype=torch.int)
4025*da0073e9SAndroid Build Coastguard Worker        v = a._version
4026*da0073e9SAndroid Build Coastguard Worker        torch.autograd.graph.increment_version(a)
4027*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a._version, v + 1)
4028*da0073e9SAndroid Build Coastguard Worker
4029*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
4030*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(5, requires_grad=True)
4031*da0073e9SAndroid Build Coastguard Worker            # does not error
4032*da0073e9SAndroid Build Coastguard Worker            torch.autograd.graph.increment_version(a)
4033*da0073e9SAndroid Build Coastguard Worker
4034*da0073e9SAndroid Build Coastguard Worker        # does not error
4035*da0073e9SAndroid Build Coastguard Worker        torch.autograd.graph.increment_version(a)
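
    # A minimal illustrative sketch (not a test), assuming the documented
    # purpose of increment_version: it is meant for code that mutates a
    # tensor's storage outside of autograd's view (e.g. via an external
    # library), so that a later backward can still notice that a saved
    # tensor became stale.
    def _sketch_increment_version_usage(self):
        x = torch.randn(3, requires_grad=True)
        y = (x * x).sum()  # mul saves x for its backward
        # Suppose x's memory was mutated out-of-band here; tell autograd:
        torch.autograd.graph.increment_version(x)
        # y.backward() would now be expected to raise the usual
        # "modified by an inplace operation" error.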
4036*da0073e9SAndroid Build Coastguard Worker
4037*da0073e9SAndroid Build Coastguard Worker    def test_no_grad_input(self):
4038*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
4039*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4040*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
4041*da0073e9SAndroid Build Coastguard Worker                return x
4042*da0073e9SAndroid Build Coastguard Worker
4043*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4044*da0073e9SAndroid Build Coastguard Worker            def backward(self, grad_output):
4045*da0073e9SAndroid Build Coastguard Worker                return grad_output
4046*da0073e9SAndroid Build Coastguard Worker
4047*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, requires_grad=True)
4048*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
4049*da0073e9SAndroid Build Coastguard Worker            y = MyFunction.apply(x)
4050*da0073e9SAndroid Build Coastguard Worker
4051*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.requires_grad)
4052*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(y.grad_fn)
4053*da0073e9SAndroid Build Coastguard Worker
4054*da0073e9SAndroid Build Coastguard Worker    def test_backward_copy(self):
4055*da0073e9SAndroid Build Coastguard Worker        # This test checks the backward engine for a very subtle bug that appeared
4056*da0073e9SAndroid Build Coastguard Worker        # in one of the initial versions of autograd. Gradient tensors were
4057*da0073e9SAndroid Build Coastguard Worker        # simply stored in lists while the function waited for all its gradients
4058*da0073e9SAndroid Build Coastguard Worker        # to be computed. However, sometimes an output was used multiple times,
4059*da0073e9SAndroid Build Coastguard Worker        # so the gradients needed to be summed. The engine used to keep a need_copy
4060*da0073e9SAndroid Build Coastguard Worker        # set of tensors that would need a clone upon the next addition, and removed
4061*da0073e9SAndroid Build Coastguard Worker        # them from the set as soon as the clone was performed. However, this
4062*da0073e9SAndroid Build Coastguard Worker        # could lead to incorrect results if the same gradient tensor was
4063*da0073e9SAndroid Build Coastguard Worker        # buffered in three places in the graph:
4064*da0073e9SAndroid Build Coastguard Worker        # 1. When accumulating gradients in one of these places, it was cloned
4065*da0073e9SAndroid Build Coastguard Worker        #    and removed from the need_copy set.
4066*da0073e9SAndroid Build Coastguard Worker        # 2. When accumulating in the second place, it wasn't in the need_copy set,
4067*da0073e9SAndroid Build Coastguard Worker        #    so the gradients were simply accumulated in-place (which already
4068*da0073e9SAndroid Build Coastguard Worker        #    modified the grad in the 3rd place).
4069*da0073e9SAndroid Build Coastguard Worker        # 3. When accumulating in the third place, it wasn't in the need_copy set
4070*da0073e9SAndroid Build Coastguard Worker        #    either, so the incoming gradient was summed in-place, yielding
4071*da0073e9SAndroid Build Coastguard Worker        #    incorrect results in all functions except the first one.
4072*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
4073*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(5, 5, requires_grad=True)
4074*da0073e9SAndroid Build Coastguard Worker        # Simulate that we're in the middle of the graph
4075*da0073e9SAndroid Build Coastguard Worker        a = x + 2
4076*da0073e9SAndroid Build Coastguard Worker        b = y + 2
4077*da0073e9SAndroid Build Coastguard Worker        c = x + 2
4078*da0073e9SAndroid Build Coastguard Worker        # This op will just return grad_output two times in backward
4079*da0073e9SAndroid Build Coastguard Worker        add1 = a + b
4080*da0073e9SAndroid Build Coastguard Worker        add2 = add1 + c
4081*da0073e9SAndroid Build Coastguard Worker        # Simulate a long branch, so grad_output will get buffered.
4082*da0073e9SAndroid Build Coastguard Worker        for _ in range(4):
4083*da0073e9SAndroid Build Coastguard Worker            a = a * 2
4084*da0073e9SAndroid Build Coastguard Worker            b = b * 2
4085*da0073e9SAndroid Build Coastguard Worker            c = c * 2
4086*da0073e9SAndroid Build Coastguard Worker        branch = a + b + c
4087*da0073e9SAndroid Build Coastguard Worker        out = add2 + branch
4088*da0073e9SAndroid Build Coastguard Worker        # expected gradients are:
4089*da0073e9SAndroid Build Coastguard Worker        # for x: 34 (16 from final a, 16 from final c, 2 from add2)
4090*da0073e9SAndroid Build Coastguard Worker        # for y: 17 (16 from final b, 1 from add2)
4091*da0073e9SAndroid Build Coastguard Worker        grad_output = torch.ones(5, 5)
4092*da0073e9SAndroid Build Coastguard Worker        out.backward(grad_output)
4093*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(5, 5) * 34)
4094*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, torch.ones(5, 5) * 17)
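
    # A minimal illustrative sketch (not a test): when an intermediate tensor
    # is used multiple times, the engine sums the incoming gradients before
    # passing them on, which is the buffering/cloning behavior the comment in
    # the test above describes.
    def _sketch_grad_accumulation_on_reuse(self):
        x = torch.ones(2, requires_grad=True)
        a = x * 1  # intermediate that is reused twice below
        seen = []
        a.register_hook(seen.append)
        (a + a).sum().backward()
        assert torch.equal(seen[0], torch.full((2,), 2.0))  # 1 + 1 from each use
        assert torch.equal(x.grad, torch.full((2,), 2.0))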
4095*da0073e9SAndroid Build Coastguard Worker
4096*da0073e9SAndroid Build Coastguard Worker    def test_save_none_for_backward(self):
4097*da0073e9SAndroid Build Coastguard Worker        test_case = self
4098*da0073e9SAndroid Build Coastguard Worker
4099*da0073e9SAndroid Build Coastguard Worker        class MyFn(Function):
4100*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4101*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
4102*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(None, input, None)
4103*da0073e9SAndroid Build Coastguard Worker                return input * input
4104*da0073e9SAndroid Build Coastguard Worker
4105*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4106*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
4107*da0073e9SAndroid Build Coastguard Worker                n1, input, n2 = ctx.saved_tensors
4108*da0073e9SAndroid Build Coastguard Worker                test_case.assertIsNone(n1)
4109*da0073e9SAndroid Build Coastguard Worker                test_case.assertIsNone(n2)
4110*da0073e9SAndroid Build Coastguard Worker                return 2 * input * grad_output
4111*da0073e9SAndroid Build Coastguard Worker
4112*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
4113*da0073e9SAndroid Build Coastguard Worker        y = MyFn.apply(x)
4114*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
4115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, 2 * x)
4116*da0073e9SAndroid Build Coastguard Worker
4117*da0073e9SAndroid Build Coastguard Worker    def test_too_many_grads(self):
4118*da0073e9SAndroid Build Coastguard Worker        class MyFn(Function):
4119*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4120*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
4121*da0073e9SAndroid Build Coastguard Worker                return input
4122*da0073e9SAndroid Build Coastguard Worker
4123*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4124*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
4125*da0073e9SAndroid Build Coastguard Worker                return grad_output, None, None
4126*da0073e9SAndroid Build Coastguard Worker
4127*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, requires_grad=True)
4128*da0073e9SAndroid Build Coastguard Worker        y = MyFn.apply(x)
4129*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
4130*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones_like(x))
4131*da0073e9SAndroid Build Coastguard Worker
4132*da0073e9SAndroid Build Coastguard Worker    def test_pickle(self):
4133*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10, requires_grad=True)
4134*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10, 10, requires_grad=False)
4135*da0073e9SAndroid Build Coastguard Worker
4136*da0073e9SAndroid Build Coastguard Worker        def assert_strict_equal(var1, var2):
4137*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(var1, var2)
4138*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(var1.requires_grad, var2.requires_grad)
4139*da0073e9SAndroid Build Coastguard Worker
4140*da0073e9SAndroid Build Coastguard Worker        serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)]
4141*da0073e9SAndroid Build Coastguard Worker        for dump in serialized:
4142*da0073e9SAndroid Build Coastguard Worker            xc, yc = pickle.loads(dump)
4143*da0073e9SAndroid Build Coastguard Worker            assert_strict_equal(xc, x)
4144*da0073e9SAndroid Build Coastguard Worker            assert_strict_equal(yc, y)
4145*da0073e9SAndroid Build Coastguard Worker
4146*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
4147*da0073e9SAndroid Build Coastguard Worker    def test_dep_nograd(self):
4148*da0073e9SAndroid Build Coastguard Worker        class F1(Function):
4149*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4150*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
4151*da0073e9SAndroid Build Coastguard Worker                out = torch.randn(input.size())
4152*da0073e9SAndroid Build Coastguard Worker                ctx.mark_non_differentiable(out)
4153*da0073e9SAndroid Build Coastguard Worker                return input, out
4154*da0073e9SAndroid Build Coastguard Worker
4155*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4156*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output, ignored):
4157*da0073e9SAndroid Build Coastguard Worker                return grad_output
4158*da0073e9SAndroid Build Coastguard Worker
4159*da0073e9SAndroid Build Coastguard Worker        class F2(Function):
4160*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4161*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input, ignored):
4162*da0073e9SAndroid Build Coastguard Worker                return input
4163*da0073e9SAndroid Build Coastguard Worker
4164*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4165*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
4166*da0073e9SAndroid Build Coastguard Worker                return grad_output, None
4167*da0073e9SAndroid Build Coastguard Worker
4168*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, requires_grad=True)
4169*da0073e9SAndroid Build Coastguard Worker        a, b = F1.apply(x)
4170*da0073e9SAndroid Build Coastguard Worker        b = b + 1  # separate F1 from F2 by another op
4171*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(a.requires_grad)
4172*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(b.requires_grad)
4173*da0073e9SAndroid Build Coastguard Worker        c = F2.apply(a, b)
4174*da0073e9SAndroid Build Coastguard Worker        c.backward(torch.ones(c.size()))
4175*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(x.size()))
4176*da0073e9SAndroid Build Coastguard Worker
4177*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_enabled(self):
4178*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1.0], requires_grad=True)
4179*da0073e9SAndroid Build Coastguard Worker        with torch.set_grad_enabled(False):
4180*da0073e9SAndroid Build Coastguard Worker            y = x * 2
4181*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(y.requires_grad)
4182*da0073e9SAndroid Build Coastguard Worker        with torch.set_grad_enabled(True):
4183*da0073e9SAndroid Build Coastguard Worker            y = x * 2
4184*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.requires_grad)
4185*da0073e9SAndroid Build Coastguard Worker        with torch.set_grad_enabled(False):
4186*da0073e9SAndroid Build Coastguard Worker            torch.set_grad_enabled(True)
4187*da0073e9SAndroid Build Coastguard Worker            y = x * 2
4188*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(y.requires_grad)
4189*da0073e9SAndroid Build Coastguard Worker
4190*da0073e9SAndroid Build Coastguard Worker    def test_set_grad_enabled_wraps(self):
4191*da0073e9SAndroid Build Coastguard Worker        for decorator in [True, False]:
4192*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
4193*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_grad_enabled())
4194*da0073e9SAndroid Build Coastguard Worker
4195*da0073e9SAndroid Build Coastguard Worker                if decorator:
4196*da0073e9SAndroid Build Coastguard Worker                    # This should not mutate the global grad mode!
4197*da0073e9SAndroid Build Coastguard Worker                    @torch.set_grad_enabled(False)
4198*da0073e9SAndroid Build Coastguard Worker                    def inner_func(x):
4199*da0073e9SAndroid Build Coastguard Worker                        return x.sin()
4200*da0073e9SAndroid Build Coastguard Worker
4201*da0073e9SAndroid Build Coastguard Worker                else:
4202*da0073e9SAndroid Build Coastguard Worker
4203*da0073e9SAndroid Build Coastguard Worker                    def inner_func(x):
4204*da0073e9SAndroid Build Coastguard Worker                        return x.sin()
4205*da0073e9SAndroid Build Coastguard Worker
4206*da0073e9SAndroid Build Coastguard Worker                    # This is non-idiomatic usage!
4207*da0073e9SAndroid Build Coastguard Worker                    # More idiomatic usage: torch.set_grad_enabled(False)(inner_func)
4208*da0073e9SAndroid Build Coastguard Worker                    obj = torch.set_grad_enabled(False)
4209*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(not torch.is_grad_enabled())
4210*da0073e9SAndroid Build Coastguard Worker
4211*da0073e9SAndroid Build Coastguard Worker                    # this will consume the set_grad_enabled global mutation!
4212*da0073e9SAndroid Build Coastguard Worker                    inner_func = obj(inner_func)
4213*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch.is_grad_enabled())
4214*da0073e9SAndroid Build Coastguard Worker
4215*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_grad_enabled())
4216*da0073e9SAndroid Build Coastguard Worker
4217*da0073e9SAndroid Build Coastguard Worker                x = torch.zeros(1, requires_grad=True)
4218*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(not inner_func(x).requires_grad)
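
    # A minimal illustrative sketch (not a test) of the three ways
    # torch.set_grad_enabled is used: as a context manager, as a decorator
    # (which, as asserted above, does not leave the global mode disabled),
    # or as a plain call that flips the global mode until changed again.
    def _sketch_set_grad_enabled_modes(self):
        x = torch.ones(1, requires_grad=True)
        with torch.set_grad_enabled(False):  # context manager
            assert not (x * 2).requires_grad

        @torch.set_grad_enabled(False)  # decorator: grad disabled only inside f
        def f(t):
            return t * 2

        assert not f(x).requires_grad
        assert torch.is_grad_enabled()  # global mode restored

        torch.set_grad_enabled(False)  # plain call mutates the global mode
        assert not torch.is_grad_enabled()
        torch.set_grad_enabled(True)
        assert torch.is_grad_enabled()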
4219*da0073e9SAndroid Build Coastguard Worker
4220*da0073e9SAndroid Build Coastguard Worker    def test_simple_reentrant(self):
4221*da0073e9SAndroid Build Coastguard Worker        y_data = torch.randn(2, 2)
4222*da0073e9SAndroid Build Coastguard Worker
4223*da0073e9SAndroid Build Coastguard Worker        class Reenter(Function):
4224*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4225*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
4226*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
4227*da0073e9SAndroid Build Coastguard Worker                    ctx.x = Variable(x, requires_grad=True)
4228*da0073e9SAndroid Build Coastguard Worker                    ctx.y = Variable(y_data, requires_grad=True)
4229*da0073e9SAndroid Build Coastguard Worker                    ctx.output_var = ctx.x * ctx.y
4230*da0073e9SAndroid Build Coastguard Worker                return ctx.output_var.detach()
4231*da0073e9SAndroid Build Coastguard Worker
4232*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4233*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
4234*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
4235*da0073e9SAndroid Build Coastguard Worker                    ctx.output_var.sum().backward()
4236*da0073e9SAndroid Build Coastguard Worker                return ctx.x.grad * grad_output
4237*da0073e9SAndroid Build Coastguard Worker
4238*da0073e9SAndroid Build Coastguard Worker        # Reentrant starts on CPU thread, finishes on GPU thread
4239*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
4240*da0073e9SAndroid Build Coastguard Worker        out = Reenter.apply(x)
4241*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
4242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, y_data)
4243*da0073e9SAndroid Build Coastguard Worker
4244*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_child_error(self):
4245*da0073e9SAndroid Build Coastguard Worker        # Parent graph.
4246*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(3, 3, requires_grad=True)
4247*da0073e9SAndroid Build Coastguard Worker        c = a * a
4248*da0073e9SAndroid Build Coastguard Worker
4249*da0073e9SAndroid Build Coastguard Worker        # Reentrant child graph.
4250*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(3, 3, requires_grad=True)
4251*da0073e9SAndroid Build Coastguard Worker        e = b * b
4252*da0073e9SAndroid Build Coastguard Worker        f = TestAutograd.SimulateBackwardError.apply(e)
4253*da0073e9SAndroid Build Coastguard Worker        reentrant_root = f.sum()
4254*da0073e9SAndroid Build Coastguard Worker
4255*da0073e9SAndroid Build Coastguard Worker        class ReentrantFunc(Function):
4256*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4257*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp):
4258*da0073e9SAndroid Build Coastguard Worker                return inp.clone()
4259*da0073e9SAndroid Build Coastguard Worker
4260*da0073e9SAndroid Build Coastguard Worker            @staticmethod
4261*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
4262*da0073e9SAndroid Build Coastguard Worker                # Reentrant backward in child will throw an error.
4263*da0073e9SAndroid Build Coastguard Worker                reentrant_root.backward()
4264*da0073e9SAndroid Build Coastguard Worker                return grad
4265*da0073e9SAndroid Build Coastguard Worker
4266*da0073e9SAndroid Build Coastguard Worker        d = ReentrantFunc.apply(c)
4267*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Simulate error"):
4268*da0073e9SAndroid Build Coastguard Worker            d.sum().backward()
4269*da0073e9SAndroid Build Coastguard Worker
4270*da0073e9SAndroid Build Coastguard Worker    def test_var_mean_differentiable(self):
4271*da0073e9SAndroid Build Coastguard Worker        dim = [2, 4]
4272*da0073e9SAndroid Build Coastguard Worker        keepdim = False
4273*da0073e9SAndroid Build Coastguard Worker        input1 = torch.randn(3, 4, 5, 6, 2, 3, requires_grad=True)
4274*da0073e9SAndroid Build Coastguard Worker        input2 = deepcopy(input1)
4275*da0073e9SAndroid Build Coastguard Worker        var1, mean1 = torch.var_mean(input1, dim=dim, keepdim=keepdim)
4276*da0073e9SAndroid Build Coastguard Worker        var2 = input2.var(dim=dim, keepdim=keepdim)
4277*da0073e9SAndroid Build Coastguard Worker        mean2 = input2.mean(dim=dim, keepdim=keepdim)
4278*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(3, 4, 6, 3, requires_grad=True)
4279*da0073e9SAndroid Build Coastguard Worker
4280*da0073e9SAndroid Build Coastguard Worker        r1 = var1 * var1 * mean1 * mean1
4281*da0073e9SAndroid Build Coastguard Worker        r2 = var2 * var2 * mean2 * mean2
4282*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(r1, r2, rtol=0.01, atol=0.0)
4283*da0073e9SAndroid Build Coastguard Worker
4284*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(r1, grad)
4285*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward(r2, grad)
4286*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(input1.grad, input2.grad, rtol=0.01, atol=0.0)
4287*da0073e9SAndroid Build Coastguard Worker
4288*da0073e9SAndroid Build Coastguard Worker    @skipIfNoLapack
4289*da0073e9SAndroid Build Coastguard Worker    def test_lobpcg(self):
4290*da0073e9SAndroid Build Coastguard Worker        def func(k, A, largest=True, B=None):
4291*da0073e9SAndroid Build Coastguard Worker            X_shape = list(A.shape)
4292*da0073e9SAndroid Build Coastguard Worker            X_shape[-1] = k
4293*da0073e9SAndroid Build Coastguard Worker            X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device)
4294*da0073e9SAndroid Build Coastguard Worker            if A.dim() > 2:
4295*da0073e9SAndroid Build Coastguard Worker                X = X.expand(X_shape)
4296*da0073e9SAndroid Build Coastguard Worker
4297*da0073e9SAndroid Build Coastguard Worker            D, U = torch.lobpcg(A=A, k=k, B=B, X=X, largest=largest)
4298*da0073e9SAndroid Build Coastguard Worker
4299*da0073e9SAndroid Build Coastguard Worker            # LOBPCG uses a random initial eigenspace approximation
4300*da0073e9SAndroid Build Coastguard Worker            # if parameter `X` is not provided.
4301*da0073e9SAndroid Build Coastguard Worker            # This may cause non-deterministic behavior
4302*da0073e9SAndroid Build Coastguard Worker            # in the sign of an eigenvector
4303*da0073e9SAndroid Build Coastguard Worker            # (note that if v is an eigenvector, so is -v),
4304*da0073e9SAndroid Build Coastguard Worker            # hence we eliminate this non-determinism
4305*da0073e9SAndroid Build Coastguard Worker            # by making sure that each column of U
4306*da0073e9SAndroid Build Coastguard Worker            # gets multiplied by the sign of its max (in absolute value) element.
4307*da0073e9SAndroid Build Coastguard Worker            # Also, gradcheck perturbs the input by +/- eps (defaults to 1e-06)
4308*da0073e9SAndroid Build Coastguard Worker            # to compute the numerical gradient, which can also cause the signs to flip.
4309*da0073e9SAndroid Build Coastguard Worker            _, idx = U.abs().max(-2, keepdim=True)
4310*da0073e9SAndroid Build Coastguard Worker            sign = U.gather(-2, idx).sign()
4311*da0073e9SAndroid Build Coastguard Worker            U = U * sign
4312*da0073e9SAndroid Build Coastguard Worker            return D, U
4313*da0073e9SAndroid Build Coastguard Worker
4314*da0073e9SAndroid Build Coastguard Worker        # TODO: review if this can be ported to OpInfos or moved to test_linalg.py
4315*da0073e9SAndroid Build Coastguard Worker        def run_symeig_test(k, sizes, largest=True):
4316*da0073e9SAndroid Build Coastguard Worker            A = torch.rand(*sizes).double()
4317*da0073e9SAndroid Build Coastguard Worker            A = (A @ A.mT) / 10
4318*da0073e9SAndroid Build Coastguard Worker            A.requires_grad_(True)
4319*da0073e9SAndroid Build Coastguard Worker
4320*da0073e9SAndroid Build Coastguard Worker            gradcheck(lambda A: func(k, A, largest), A, check_batched_grad=False)
4321*da0073e9SAndroid Build Coastguard Worker
4322*da0073e9SAndroid Build Coastguard Worker            # Use custom gradient vectors for better stability, due to some
4323*da0073e9SAndroid Build Coastguard Worker            # non-determinism in lobpcg's forward.
4324*da0073e9SAndroid Build Coastguard Worker            # Note that this is not required if symeig is used in the forward instead (tested).
4325*da0073e9SAndroid Build Coastguard Worker            D_grad = torch.rand(*A.shape[:-2], k) / 100
4326*da0073e9SAndroid Build Coastguard Worker            U_grad = torch.rand(*A.shape[:-1], k) / 100
4327*da0073e9SAndroid Build Coastguard Worker            gradgradcheck(
4328*da0073e9SAndroid Build Coastguard Worker                lambda A: func(k, A, largest),
4329*da0073e9SAndroid Build Coastguard Worker                A,
4330*da0073e9SAndroid Build Coastguard Worker                [D_grad, U_grad],
4331*da0073e9SAndroid Build Coastguard Worker                atol=1e-4,
4332*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
4333*da0073e9SAndroid Build Coastguard Worker            )
4334*da0073e9SAndroid Build Coastguard Worker
4335*da0073e9SAndroid Build Coastguard Worker            # check whether A.grad is symmetric
4336*da0073e9SAndroid Build Coastguard Worker            A = A.detach().requires_grad_(True)
4337*da0073e9SAndroid Build Coastguard Worker            D, U = func(k, A, largest)
4338*da0073e9SAndroid Build Coastguard Worker            (D.sum() + U.sum()).backward()
4339*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(A.grad, A.grad.mT)
4340*da0073e9SAndroid Build Coastguard Worker
4341*da0073e9SAndroid Build Coastguard Worker        for largest in [True, False]:
4342*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(1, (6, 6), largest=largest)
4343*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(1, (2, 6, 6), largest=largest)
4344*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(1, (2, 2, 6, 6), largest=largest)
4345*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(2, (6, 6), largest=largest)
4346*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(2, (2, 6, 6), largest=largest)
4347*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(2, (2, 2, 6, 6), largest=largest)
4348*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(3, (9, 9), largest=largest)
4349*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(3, (2, 9, 9), largest=largest)
4350*da0073e9SAndroid Build Coastguard Worker            run_symeig_test(3, (2, 2, 9, 9), largest=largest)
4351*da0073e9SAndroid Build Coastguard Worker
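    # Minimal sketch (helper name is illustrative, not part of the upstream suite):
    # the sign-normalization trick from test_lobpcg above, factored out; it removes
    # the v / -v ambiguity by flipping each column of U so that its largest-magnitude
    # entry is positive.
    @staticmethod
    def _sketch_fix_eigenvector_signs(U):
        _, idx = U.abs().max(-2, keepdim=True)
        return U * U.gather(-2, idx).sign()
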
4352*da0073e9SAndroid Build Coastguard Worker    def test_variable_traverse(self):
4353*da0073e9SAndroid Build Coastguard Worker        def get_out_and_unrefed_cycle():
4354*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(10, requires_grad=True)
4355*da0073e9SAndroid Build Coastguard Worker            tmp = inp.view(10, 1)
4356*da0073e9SAndroid Build Coastguard Worker            out = tmp.view(10)
4357*da0073e9SAndroid Build Coastguard Worker
4358*da0073e9SAndroid Build Coastguard Worker            # Create a reference cycle that contains an
4359*da0073e9SAndroid Build Coastguard Worker            # intermediary Variable in the graph
4360*da0073e9SAndroid Build Coastguard Worker            my_list = []
4361*da0073e9SAndroid Build Coastguard Worker            my_list.append(tmp)
4362*da0073e9SAndroid Build Coastguard Worker            my_list.append(my_list)
4363*da0073e9SAndroid Build Coastguard Worker
4364*da0073e9SAndroid Build Coastguard Worker            return out
4365*da0073e9SAndroid Build Coastguard Worker
4366*da0073e9SAndroid Build Coastguard Worker        out = get_out_and_unrefed_cycle()
4367*da0073e9SAndroid Build Coastguard Worker        gc.collect()
4368*da0073e9SAndroid Build Coastguard Worker        # This will segfault if things have been erroneously released
4369*da0073e9SAndroid Build Coastguard Worker        out.backward(torch.randn(out.size()))
4370*da0073e9SAndroid Build Coastguard Worker
4371*da0073e9SAndroid Build Coastguard Worker    # TODO: review porting these to OpInfo tests
4372*da0073e9SAndroid Build Coastguard Worker    def test_pow_zero_tensor_gradient(self):
4373*da0073e9SAndroid Build Coastguard Worker        def run_test(input_size, exponent):
4374*da0073e9SAndroid Build Coastguard Worker            input = torch.zeros(*input_size, requires_grad=True)
4375*da0073e9SAndroid Build Coastguard Worker            input.pow(exponent).sum().backward()
4376*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(input.grad.abs().sum(), 0)
4377*da0073e9SAndroid Build Coastguard Worker
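        # All exponents below are zero; the naive derivative e * x**(e - 1) would hit
        # 0 * inf at a zero base, and these checks confirm autograd produces a plain
        # zero gradient rather than nan.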
4378*da0073e9SAndroid Build Coastguard Worker        run_test((10,), torch.zeros(10))
4379*da0073e9SAndroid Build Coastguard Worker        run_test((10, 10), torch.zeros(10, 10))
4380*da0073e9SAndroid Build Coastguard Worker        run_test((10,), 0)
4381*da0073e9SAndroid Build Coastguard Worker
4382*da0073e9SAndroid Build Coastguard Worker    def test_current_graph_task_id(self):
4383*da0073e9SAndroid Build Coastguard Worker        id = [-1]
4384*da0073e9SAndroid Build Coastguard Worker
4385*da0073e9SAndroid Build Coastguard Worker        def hook(_):
4386*da0073e9SAndroid Build Coastguard Worker            id[0] = torch._C._current_graph_task_id()
4387*da0073e9SAndroid Build Coastguard Worker
4388*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(1.0, requires_grad=True).clone()
4389*da0073e9SAndroid Build Coastguard Worker        t.register_hook(hook)
4390*da0073e9SAndroid Build Coastguard Worker
4391*da0073e9SAndroid Build Coastguard Worker        t.backward(retain_graph=True)
4392*da0073e9SAndroid Build Coastguard Worker        base = id[0]
4393*da0073e9SAndroid Build Coastguard Worker        t.backward(retain_graph=True)
4394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id[0] - base, 1)
4395*da0073e9SAndroid Build Coastguard Worker        t.backward(retain_graph=True)
4396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(id[0] - base, 2)
4397*da0073e9SAndroid Build Coastguard Worker
4398*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(torch._C._current_graph_task_id(), -1)
4399*da0073e9SAndroid Build Coastguard Worker
4400*da0073e9SAndroid Build Coastguard Worker    def test_current_graph_task_execution_order(self):
4401*da0073e9SAndroid Build Coastguard Worker        predicted = [None]
4402*da0073e9SAndroid Build Coastguard Worker
4403*da0073e9SAndroid Build Coastguard Worker        def hook(_):
4404*da0073e9SAndroid Build Coastguard Worker            predicted[0] = torch._C._current_graph_task_execution_order()
4405*da0073e9SAndroid Build Coastguard Worker
4406*da0073e9SAndroid Build Coastguard Worker        def names(nodes):
4407*da0073e9SAndroid Build Coastguard Worker            return ", ".join([node.name().split(" ")[-1] for node in nodes]) + "\n"
4408*da0073e9SAndroid Build Coastguard Worker
4409*da0073e9SAndroid Build Coastguard Worker        def grad_fns(*tensors):
4410*da0073e9SAndroid Build Coastguard Worker            # Returns each tensor's grad_fn, or its gradient accumulator node for leaves
4411*da0073e9SAndroid Build Coastguard Worker            out = []
4412*da0073e9SAndroid Build Coastguard Worker            for t in tensors:
4413*da0073e9SAndroid Build Coastguard Worker                if t.requires_grad and t.grad_fn is None:
4414*da0073e9SAndroid Build Coastguard Worker                    out.append(t.clone().grad_fn.next_functions[0][0])
4415*da0073e9SAndroid Build Coastguard Worker                else:
4416*da0073e9SAndroid Build Coastguard Worker                    out.append(t.grad_fn)
4417*da0073e9SAndroid Build Coastguard Worker            return out
4418*da0073e9SAndroid Build Coastguard Worker
4419*da0073e9SAndroid Build Coastguard Worker        actual = []
4420*da0073e9SAndroid Build Coastguard Worker
4421*da0073e9SAndroid Build Coastguard Worker        def register_logging_hooks(*tensors):
4422*da0073e9SAndroid Build Coastguard Worker            # register hooks that log the order in which they are called
4423*da0073e9SAndroid Build Coastguard Worker            def get_hook(i):
4424*da0073e9SAndroid Build Coastguard Worker                def hook(t_):
4425*da0073e9SAndroid Build Coastguard Worker                    actual.append(tensors[i])
4426*da0073e9SAndroid Build Coastguard Worker
4427*da0073e9SAndroid Build Coastguard Worker                return hook
4428*da0073e9SAndroid Build Coastguard Worker
4429*da0073e9SAndroid Build Coastguard Worker            for i, t in enumerate(tensors):
4430*da0073e9SAndroid Build Coastguard Worker                t.register_hook(get_hook(i))
4431*da0073e9SAndroid Build Coastguard Worker
4432*da0073e9SAndroid Build Coastguard Worker        # Basic example: single path
4433*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
4434*da0073e9SAndroid Build Coastguard Worker        t.register_hook(hook)
4435*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4436*da0073e9SAndroid Build Coastguard Worker            t.backward()
4437*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
4438*da0073e9SAndroid Build Coastguard Worker            names(predicted[0]),
4439*da0073e9SAndroid Build Coastguard Worker            """\
4440*da0073e9SAndroid Build Coastguard WorkerExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
4441*da0073e9SAndroid Build Coastguard Worker""",
4442*da0073e9SAndroid Build Coastguard Worker        )
4443*da0073e9SAndroid Build Coastguard Worker
4444*da0073e9SAndroid Build Coastguard Worker        # We don't exactly follow sequence_nr order
4445*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
4446*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(2.0, requires_grad=True)
4447*da0073e9SAndroid Build Coastguard Worker        c = b.sin()
4448*da0073e9SAndroid Build Coastguard Worker        d = a.cos()
4449*da0073e9SAndroid Build Coastguard Worker        out = c * d
4450*da0073e9SAndroid Build Coastguard Worker        register_logging_hooks(a, b, c, d, out)
4451*da0073e9SAndroid Build Coastguard Worker        out.register_hook(hook)
4452*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4453*da0073e9SAndroid Build Coastguard Worker            out.backward()
4454*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(predicted[0], grad_fns(*actual))
4455*da0073e9SAndroid Build Coastguard Worker        actual = []
4456*da0073e9SAndroid Build Coastguard Worker
4457*da0073e9SAndroid Build Coastguard Worker        # Accumulate grad node has more than one input
4458*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
4459*da0073e9SAndroid Build Coastguard Worker        b = a.sin()
4460*da0073e9SAndroid Build Coastguard Worker        c = a.cos()
4461*da0073e9SAndroid Build Coastguard Worker        out = b * c
4462*da0073e9SAndroid Build Coastguard Worker        register_logging_hooks(a, b, c, out)
4463*da0073e9SAndroid Build Coastguard Worker        out.register_hook(hook)
4464*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4465*da0073e9SAndroid Build Coastguard Worker            out.backward()
4466*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(predicted[0], grad_fns(*actual))
4467*da0073e9SAndroid Build Coastguard Worker        actual = []
4468*da0073e9SAndroid Build Coastguard Worker
4469*da0073e9SAndroid Build Coastguard Worker        # Multiple roots are also OK
4470*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
4471*da0073e9SAndroid Build Coastguard Worker        b = a * 2
4472*da0073e9SAndroid Build Coastguard Worker        out = b.sin()
4473*da0073e9SAndroid Build Coastguard Worker        out2 = b.cos()
4474*da0073e9SAndroid Build Coastguard Worker        out3 = b.cos()
4475*da0073e9SAndroid Build Coastguard Worker        register_logging_hooks(a, b, out, out2, out3)
4476*da0073e9SAndroid Build Coastguard Worker        out3.register_hook(hook)
4477*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4478*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad((out, out3, out2), inputs=(a,))
4479*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
4480*da0073e9SAndroid Build Coastguard Worker            names(predicted[0]),
4481*da0073e9SAndroid Build Coastguard Worker            """\
4482*da0073e9SAndroid Build Coastguard WorkerCosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
4483*da0073e9SAndroid Build Coastguard Worker""",
4484*da0073e9SAndroid Build Coastguard Worker        )
4485*da0073e9SAndroid Build Coastguard Worker        # TODO: Uncomment after update to hooks behavior
4486*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(predicted[0], grad_fns(*actual))
4487*da0073e9SAndroid Build Coastguard Worker        actual = []
4488*da0073e9SAndroid Build Coastguard Worker
4489*da0073e9SAndroid Build Coastguard Worker        # Case where next node is nullptr
4490*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
4491*da0073e9SAndroid Build Coastguard Worker        b = a * 2
4492*da0073e9SAndroid Build Coastguard Worker        out = b.sin()
4493*da0073e9SAndroid Build Coastguard Worker        register_logging_hooks(a, b, out)
4494*da0073e9SAndroid Build Coastguard Worker        out.register_hook(hook)
4495*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4496*da0073e9SAndroid Build Coastguard Worker            out.backward()
4497*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(predicted[0], grad_fns(*actual))
4498*da0073e9SAndroid Build Coastguard Worker        actual = []
4499*da0073e9SAndroid Build Coastguard Worker
4500*da0073e9SAndroid Build Coastguard Worker        # Case where two `inputs` are on the same path
4501*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
4502*da0073e9SAndroid Build Coastguard Worker        b = a * 2
4503*da0073e9SAndroid Build Coastguard Worker        out = b.sin()
4504*da0073e9SAndroid Build Coastguard Worker        register_logging_hooks(a, b, out)
4505*da0073e9SAndroid Build Coastguard Worker        out.register_hook(hook)
4506*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4507*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad((out,), inputs=(a, b))
4508*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4509*da0073e9SAndroid Build Coastguard Worker            names(predicted[0]),
4510*da0073e9SAndroid Build Coastguard Worker            """\
4511*da0073e9SAndroid Build Coastguard WorkerSinBackward0, MulBackward0, torch::autograd::AccumulateGrad
4512*da0073e9SAndroid Build Coastguard Worker""",
4513*da0073e9SAndroid Build Coastguard Worker        )
4514*da0073e9SAndroid Build Coastguard Worker        # TODO: Uncomment after update to hooks behavior
4515*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(predicted[0], grad_fns(*actual))
4516*da0073e9SAndroid Build Coastguard Worker        actual = []
4517*da0073e9SAndroid Build Coastguard Worker
4518*da0073e9SAndroid Build Coastguard Worker        # Case where `inputs` specifies a subgraph
4519*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
4520*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(1.0, requires_grad=True)
4521*da0073e9SAndroid Build Coastguard Worker        c = a * b
4522*da0073e9SAndroid Build Coastguard Worker        out = c.sin()
4523*da0073e9SAndroid Build Coastguard Worker        register_logging_hooks(a, b, c, out)
4524*da0073e9SAndroid Build Coastguard Worker        out.register_hook(hook)
4525*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
4526*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad((out,), inputs=(a,))
4527*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
4528*da0073e9SAndroid Build Coastguard Worker            names(predicted[0]),
4529*da0073e9SAndroid Build Coastguard Worker            """\
4530*da0073e9SAndroid Build Coastguard WorkerSinBackward0, MulBackward0, torch::autograd::AccumulateGrad
4531*da0073e9SAndroid Build Coastguard Worker""",
4532*da0073e9SAndroid Build Coastguard Worker        )
4533*da0073e9SAndroid Build Coastguard Worker        # TODO: Uncomment after update to hooks behavior
4534*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(predicted[0], grad_fns(*actual))
4535*da0073e9SAndroid Build Coastguard Worker        actual = []
4536*da0073e9SAndroid Build Coastguard Worker
4537*da0073e9SAndroid Build Coastguard Worker        # Errors when not called in a backward
4538*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4539*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "should only be called during the backward pass"
4540*da0073e9SAndroid Build Coastguard Worker        ):
4541*da0073e9SAndroid Build Coastguard Worker            torch._C._current_graph_task_execution_order()
4542*da0073e9SAndroid Build Coastguard Worker
4543*da0073e9SAndroid Build Coastguard Worker        # Errors when the backward is not run with multithreading disabled
4544*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
4545*da0073e9SAndroid Build Coastguard Worker        t.register_hook(hook)
4546*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
4547*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
4548*da0073e9SAndroid Build Coastguard Worker            "expects the current backward to be executed with multithreading disabled",
4549*da0073e9SAndroid Build Coastguard Worker        ):
4550*da0073e9SAndroid Build Coastguard Worker            t.backward()
4551*da0073e9SAndroid Build Coastguard Worker
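    # Minimal sketch (helper name is illustrative, not part of the upstream suite):
    # the t.clone().grad_fn.next_functions[0][0] trick used by grad_fns above is one
    # way to reach the AccumulateGrad node of a leaf tensor; cloning creates a
    # CloneBackward node whose single input edge points at the leaf's accumulator.
    @staticmethod
    def _sketch_accumulate_grad_node(leaf):
        assert leaf.requires_grad and leaf.grad_fn is None
        return leaf.clone().grad_fn.next_functions[0][0]
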
4552*da0073e9SAndroid Build Coastguard Worker    def test_view_replay_enabled(self):
4553*da0073e9SAndroid Build Coastguard Worker        def f(x):
4554*da0073e9SAndroid Build Coastguard Worker            out = x.clone().view(-1)
4555*da0073e9SAndroid Build Coastguard Worker            # mutate the view, triggering autograd view-replay logic
4556*da0073e9SAndroid Build Coastguard Worker            out.add_(1)
4557*da0073e9SAndroid Build Coastguard Worker            return out
4558*da0073e9SAndroid Build Coastguard Worker
4559*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, 2, requires_grad=True)
4560*da0073e9SAndroid Build Coastguard Worker
4561*da0073e9SAndroid Build Coastguard Worker        # Test as a context manager
4562*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._force_original_view_tracking(False):
4563*da0073e9SAndroid Build Coastguard Worker            out = f(x)
4564*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("AsStridedBackward" in str(out.grad_fn))
4565*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.autograd.is_view_replay_enabled())
4566*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.is_view_replay_enabled())
4567*da0073e9SAndroid Build Coastguard Worker
4568*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._force_original_view_tracking(True):
4569*da0073e9SAndroid Build Coastguard Worker            out = f(x)
4570*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("ViewBackward" in str(out.grad_fn))
4571*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd.is_view_replay_enabled())
4572*da0073e9SAndroid Build Coastguard Worker        out = f(x)
4573*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
4574*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.is_view_replay_enabled())
4575*da0073e9SAndroid Build Coastguard Worker
4576*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._force_original_view_tracking(False):
4577*da0073e9SAndroid Build Coastguard Worker            torch.autograd._force_original_view_tracking(True)
4578*da0073e9SAndroid Build Coastguard Worker            out = f(x)
4579*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("ViewBackward" in str(out.grad_fn))
4580*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd.is_view_replay_enabled())
4581*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.is_view_replay_enabled())
4582*da0073e9SAndroid Build Coastguard Worker
4583*da0073e9SAndroid Build Coastguard Worker        # Test as a function
4584*da0073e9SAndroid Build Coastguard Worker        torch.autograd._force_original_view_tracking(False)
4585*da0073e9SAndroid Build Coastguard Worker        out = f(x)
4586*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
4587*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.is_view_replay_enabled())
4588*da0073e9SAndroid Build Coastguard Worker
4589*da0073e9SAndroid Build Coastguard Worker        torch.autograd._force_original_view_tracking(True)
4590*da0073e9SAndroid Build Coastguard Worker        out = f(x)
4591*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("ViewBackward" in str(out.grad_fn))
4592*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.autograd.is_view_replay_enabled())
4593*da0073e9SAndroid Build Coastguard Worker
4594*da0073e9SAndroid Build Coastguard Worker    def test_unsafe_set_version_counter(self):
4595*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(2, requires_grad=True).clone()
4596*da0073e9SAndroid Build Coastguard Worker        x.add_(1)
4597*da0073e9SAndroid Build Coastguard Worker        x.add_(2)
4598*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, x._version)
4599*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._unsafe_preserve_version_counter(x):
4600*da0073e9SAndroid Build Coastguard Worker            x.mul_(2)
4601*da0073e9SAndroid Build Coastguard Worker            x.mul_(3)
4602*da0073e9SAndroid Build Coastguard Worker        # the context manager restores the version counter to its value on entry
4603*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2, x._version)
4604*da0073e9SAndroid Build Coastguard Worker
4605*da0073e9SAndroid Build Coastguard Worker        torch._C._autograd._unsafe_set_version_counter(x, 0)
4606*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, x._version)
4607*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Cannot set"):
4608*da0073e9SAndroid Build Coastguard Worker            torch._C._autograd._unsafe_set_version_counter(x, -1)
4609*da0073e9SAndroid Build Coastguard Worker
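    # Minimal sketch (helper name is illustrative, not part of the upstream suite):
    # the version counter exercised above is what lets autograd detect that a tensor
    # saved for backward was mutated in place after being saved.
    def _sketch_version_counter_detects_mutation(self):
        a = torch.ones(2, requires_grad=True)
        b = a.exp()  # exp saves its output for the backward pass
        b.add_(1)  # bumps b._version past the saved value
        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
            b.sum().backward()
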
4610*da0073e9SAndroid Build Coastguard Worker    def test_current_node(self):
4611*da0073e9SAndroid Build Coastguard Worker        pr = []
4612*da0073e9SAndroid Build Coastguard Worker
4613*da0073e9SAndroid Build Coastguard Worker        class MyMode(TorchDispatchMode):
4614*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args, kwargs=None):
4615*da0073e9SAndroid Build Coastguard Worker                node = torch._C._current_autograd_node()
4616*da0073e9SAndroid Build Coastguard Worker                # Don't use node.name() here as it is not consistent on Windows
4617*da0073e9SAndroid Build Coastguard Worker                node_name = node.__class__.__name__ if node else "None"
4618*da0073e9SAndroid Build Coastguard Worker                pr.append(f"Running {func} from within {node_name}")
4619*da0073e9SAndroid Build Coastguard Worker                return func(*args, **(kwargs or {}))
4620*da0073e9SAndroid Build Coastguard Worker
4621*da0073e9SAndroid Build Coastguard Worker        with MyMode():
4622*da0073e9SAndroid Build Coastguard Worker            pr.append("FW")
4623*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(10, requires_grad=True)
4624*da0073e9SAndroid Build Coastguard Worker            b = a.mul(2).div(3).sum()
4625*da0073e9SAndroid Build Coastguard Worker            pr.append("BW")
4626*da0073e9SAndroid Build Coastguard Worker            b.backward()
4627*da0073e9SAndroid Build Coastguard Worker            pr.append("Done")
4628*da0073e9SAndroid Build Coastguard Worker
4629*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
4630*da0073e9SAndroid Build Coastguard Worker            "\n".join(pr),
4631*da0073e9SAndroid Build Coastguard Worker            """\
4632*da0073e9SAndroid Build Coastguard WorkerFW
4633*da0073e9SAndroid Build Coastguard WorkerRunning aten.rand.default from within None
4634*da0073e9SAndroid Build Coastguard WorkerRunning aten.mul.Tensor from within None
4635*da0073e9SAndroid Build Coastguard WorkerRunning aten.div.Tensor from within None
4636*da0073e9SAndroid Build Coastguard WorkerRunning aten.sum.default from within None
4637*da0073e9SAndroid Build Coastguard WorkerBW
4638*da0073e9SAndroid Build Coastguard WorkerRunning aten.ones_like.default from within None
4639*da0073e9SAndroid Build Coastguard WorkerRunning aten.expand.default from within SumBackward0
4640*da0073e9SAndroid Build Coastguard WorkerRunning aten.div.Tensor from within DivBackward0
4641*da0073e9SAndroid Build Coastguard WorkerRunning aten.mul.Tensor from within MulBackward0
4642*da0073e9SAndroid Build Coastguard WorkerRunning aten.detach.default from within AccumulateGrad
4643*da0073e9SAndroid Build Coastguard WorkerRunning aten.detach.default from within AccumulateGrad
4644*da0073e9SAndroid Build Coastguard WorkerDone""",
4645*da0073e9SAndroid Build Coastguard Worker        )
4646*da0073e9SAndroid Build Coastguard Worker
4647*da0073e9SAndroid Build Coastguard Worker    def test_profiler(self):
4648*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
4649*da0073e9SAndroid Build Coastguard Worker
4650*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4651*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd._profiler_enabled())
4652*da0073e9SAndroid Build Coastguard Worker            y = x * 2 + 4
4653*da0073e9SAndroid Build Coastguard Worker
4654*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd._profiler_enabled())
4655*da0073e9SAndroid Build Coastguard Worker
4656*da0073e9SAndroid Build Coastguard Worker        names = ["aten::mul", "aten::add"]
4657*da0073e9SAndroid Build Coastguard Worker        found_indices = set()
4658*da0073e9SAndroid Build Coastguard Worker        for evt in p.function_events:
4659*da0073e9SAndroid Build Coastguard Worker            if evt.name in names:
4660*da0073e9SAndroid Build Coastguard Worker                found_indices.add(names.index(evt.name))
4661*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(found_indices), len(names))
4662*da0073e9SAndroid Build Coastguard Worker
4663*da0073e9SAndroid Build Coastguard Worker    def test_profiler_seq_nr(self):
4664*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4665*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(10, 10, requires_grad=True)
4666*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(10, 10, requires_grad=True)
4667*da0073e9SAndroid Build Coastguard Worker            z = x + y
4668*da0073e9SAndroid Build Coastguard Worker            s = z.sum(dim=None)
4669*da0073e9SAndroid Build Coastguard Worker            s.backward()
4670*da0073e9SAndroid Build Coastguard Worker        print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
4671*da0073e9SAndroid Build Coastguard Worker        # expecting aten::add and aten::sum to have sequence numbers,
4672*da0073e9SAndroid Build Coastguard Worker        # and the corresponding backward nodes to have the same numbers
4673*da0073e9SAndroid Build Coastguard Worker        # as the forward ops
4674*da0073e9SAndroid Build Coastguard Worker        autograd_ops = {
4675*da0073e9SAndroid Build Coastguard Worker            ("aten::add", "Add"): [],
4676*da0073e9SAndroid Build Coastguard Worker            ("aten::sum", "Sum"): [],
4677*da0073e9SAndroid Build Coastguard Worker        }
4678*da0073e9SAndroid Build Coastguard Worker        accumulate_ops = []
4679*da0073e9SAndroid Build Coastguard Worker        found_empty = False
4680*da0073e9SAndroid Build Coastguard Worker        for e in p.function_events:
4681*da0073e9SAndroid Build Coastguard Worker            for (fwd_name, bwd_name), ops in autograd_ops.items():
4682*da0073e9SAndroid Build Coastguard Worker                if e.name == fwd_name or (bwd_name in e.name and "Backward" in e.name):
4683*da0073e9SAndroid Build Coastguard Worker                    ops.append(e)
4684*da0073e9SAndroid Build Coastguard Worker
4685*da0073e9SAndroid Build Coastguard Worker            if "AccumulateGrad" in e.name:
4686*da0073e9SAndroid Build Coastguard Worker                accumulate_ops.append(e)
4687*da0073e9SAndroid Build Coastguard Worker
4688*da0073e9SAndroid Build Coastguard Worker            # check that nested ops (e.g. empty) don't have
4689*da0073e9SAndroid Build Coastguard Worker            # a sequence number
4690*da0073e9SAndroid Build Coastguard Worker            if e.name == "aten::empty":
4691*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(e.sequence_nr, -1)
4692*da0073e9SAndroid Build Coastguard Worker                found_empty = True
4693*da0073e9SAndroid Build Coastguard Worker
4694*da0073e9SAndroid Build Coastguard Worker        for idx, ((fwd_name, bwd_name), ops) in enumerate(autograd_ops.items()):
4695*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(ops), 3)
4696*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[0].name, fwd_name)
4697*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
4698*da0073e9SAndroid Build Coastguard Worker                ops[1].name,
4699*da0073e9SAndroid Build Coastguard Worker                f"autograd::engine::evaluate_function: {bwd_name}Backward{idx}",
4700*da0073e9SAndroid Build Coastguard Worker            )
4701*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[2].name, f"{bwd_name}Backward{idx}")
4702*da0073e9SAndroid Build Coastguard Worker            self.assertGreaterEqual(ops[0].sequence_nr, 0)
4703*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[1].sequence_nr, ops[0].sequence_nr)
4704*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[2].sequence_nr, ops[0].sequence_nr)
4705*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[0].fwd_thread, 0)
4706*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[1].fwd_thread, ops[0].thread)
4707*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(ops[2].fwd_thread, ops[0].thread)
4708*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_empty)
4709*da0073e9SAndroid Build Coastguard Worker
4710*da0073e9SAndroid Build Coastguard Worker    def test_profiler_unboxed_only(self):
4711*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
4712*da0073e9SAndroid Build Coastguard Worker
4713*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
4714*da0073e9SAndroid Build Coastguard Worker            x.resize_([3, 2])
4715*da0073e9SAndroid Build Coastguard Worker
4716*da0073e9SAndroid Build Coastguard Worker    def test_profiler_propagation(self):
4717*da0073e9SAndroid Build Coastguard Worker        def foo(x):
4718*da0073e9SAndroid Build Coastguard Worker            with record_function("in_foo") as rf:
4719*da0073e9SAndroid Build Coastguard Worker                return x * 2
4720*da0073e9SAndroid Build Coastguard Worker
4721*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(3, 4)
4722*da0073e9SAndroid Build Coastguard Worker        traced_foo = torch.jit.trace(foo, x)
4723*da0073e9SAndroid Build Coastguard Worker
4724*da0073e9SAndroid Build Coastguard Worker        def bar(x):
4725*da0073e9SAndroid Build Coastguard Worker            with record_function("in_bar") as rf:
4726*da0073e9SAndroid Build Coastguard Worker                # we expect that the profiler will be able to
4727*da0073e9SAndroid Build Coastguard Worker                # propagate across the fork
4728*da0073e9SAndroid Build Coastguard Worker                fut = torch.jit._fork(traced_foo, x)
4729*da0073e9SAndroid Build Coastguard Worker                y = torch.jit._wait(fut)
4730*da0073e9SAndroid Build Coastguard Worker                # note: continuation (and rf's end) can
4731*da0073e9SAndroid Build Coastguard Worker                # be executed in a different thread
4732*da0073e9SAndroid Build Coastguard Worker                with record_function("in_bar_after_wait") as rf2:
4733*da0073e9SAndroid Build Coastguard Worker                    y = y * 2
4734*da0073e9SAndroid Build Coastguard Worker                return y
4735*da0073e9SAndroid Build Coastguard Worker
4736*da0073e9SAndroid Build Coastguard Worker        traced_bar = torch.jit.trace(bar, x)
4737*da0073e9SAndroid Build Coastguard Worker
4738*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4739*da0073e9SAndroid Build Coastguard Worker            traced_bar(x)
4740*da0073e9SAndroid Build Coastguard Worker
4741*da0073e9SAndroid Build Coastguard Worker        found_foo = False
4742*da0073e9SAndroid Build Coastguard Worker        found_bar = False
4743*da0073e9SAndroid Build Coastguard Worker        found_bar_after_wait = False
4744*da0073e9SAndroid Build Coastguard Worker        for info in p.function_events:
4745*da0073e9SAndroid Build Coastguard Worker            if info.name == "in_foo":
4746*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(found_foo)
4747*da0073e9SAndroid Build Coastguard Worker                found_foo = True
4748*da0073e9SAndroid Build Coastguard Worker            elif info.name == "in_bar":
4749*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(found_bar)
4750*da0073e9SAndroid Build Coastguard Worker                found_bar = True
4751*da0073e9SAndroid Build Coastguard Worker            elif info.name == "in_bar_after_wait":
4752*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(found_bar_after_wait)
4753*da0073e9SAndroid Build Coastguard Worker                found_bar_after_wait = True
4754*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_foo)
4755*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_bar)
4756*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(found_bar_after_wait)
4757*da0073e9SAndroid Build Coastguard Worker
4758*da0073e9SAndroid Build Coastguard Worker    def test_record_function_callbacks(self):
4759*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
4760*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4761*da0073e9SAndroid Build Coastguard Worker            with record_function("foo"):
4762*da0073e9SAndroid Build Coastguard Worker                y = x * 2 + 4
4763*da0073e9SAndroid Build Coastguard Worker
4764*da0073e9SAndroid Build Coastguard Worker        function_events = p.function_events
4765*da0073e9SAndroid Build Coastguard Worker        foo_event = next(event for event in function_events if "foo" in event.name)
4766*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo_event.count, 1)
4767*da0073e9SAndroid Build Coastguard Worker
4768*da0073e9SAndroid Build Coastguard Worker    def test_record_function_legacy(self):
4769*da0073e9SAndroid Build Coastguard Worker        # Test that the new _record_function ops work
4770*da0073e9SAndroid Build Coastguard Worker        # Note: Remove once record_function uses these directly
4771*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
4772*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4773*da0073e9SAndroid Build Coastguard Worker            handle = torch.ops.profiler._record_function_enter("bar", None)
4774*da0073e9SAndroid Build Coastguard Worker            try:
4775*da0073e9SAndroid Build Coastguard Worker                y = x * 2 + 4
4776*da0073e9SAndroid Build Coastguard Worker            finally:
4777*da0073e9SAndroid Build Coastguard Worker                torch.ops.profiler._record_function_exit(handle)
4778*da0073e9SAndroid Build Coastguard Worker
4779*da0073e9SAndroid Build Coastguard Worker        function_events = p.function_events
4780*da0073e9SAndroid Build Coastguard Worker        bar_event = next(event for event in function_events if "bar" in event.name)
4781*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bar_event.count, 1)
4782*da0073e9SAndroid Build Coastguard Worker
4783*da0073e9SAndroid Build Coastguard Worker    def test_profiler_aggregation_fake(self):
4784*da0073e9SAndroid Build Coastguard Worker        events = EventList()
4785*da0073e9SAndroid Build Coastguard Worker        id = [0]
4786*da0073e9SAndroid Build Coastguard Worker
4787*da0073e9SAndroid Build Coastguard Worker        def get_id():
4788*da0073e9SAndroid Build Coastguard Worker            id[0] = id[0] + 1
4789*da0073e9SAndroid Build Coastguard Worker            return id[0]
4790*da0073e9SAndroid Build Coastguard Worker
4791*da0073e9SAndroid Build Coastguard Worker        # [[thread_id, [(start, end, id), ....]], ...]
4792*da0073e9SAndroid Build Coastguard Worker        # Using list instead of a dict so order is guaranteed for any Python
4793*da0073e9SAndroid Build Coastguard Worker        # version
4794*da0073e9SAndroid Build Coastguard Worker        threads = [
4795*da0073e9SAndroid Build Coastguard Worker            [1, [(0, 1, get_id()), (1, 2, get_id())]],
4796*da0073e9SAndroid Build Coastguard Worker            [0, [(0, 2, get_id()), (1, 2, get_id()), (1, 3, get_id())]],
4797*da0073e9SAndroid Build Coastguard Worker        ]
4798*da0073e9SAndroid Build Coastguard Worker        for thread, ranges in threads:
4799*da0073e9SAndroid Build Coastguard Worker            for range in ranges:
4800*da0073e9SAndroid Build Coastguard Worker                assert len(range) == 3
4801*da0073e9SAndroid Build Coastguard Worker                events.append(
4802*da0073e9SAndroid Build Coastguard Worker                    FunctionEvent(
4803*da0073e9SAndroid Build Coastguard Worker                        id=range[2],
4804*da0073e9SAndroid Build Coastguard Worker                        node_id=0,
4805*da0073e9SAndroid Build Coastguard Worker                        name="",
4806*da0073e9SAndroid Build Coastguard Worker                        thread=thread,
4807*da0073e9SAndroid Build Coastguard Worker                        start_us=range[0],
4808*da0073e9SAndroid Build Coastguard Worker                        end_us=range[1],
4809*da0073e9SAndroid Build Coastguard Worker                    )
4810*da0073e9SAndroid Build Coastguard Worker                )
4811*da0073e9SAndroid Build Coastguard Worker
4812*da0073e9SAndroid Build Coastguard Worker        events._populate_cpu_children()
4813*da0073e9SAndroid Build Coastguard Worker
4814*da0073e9SAndroid Build Coastguard Worker        # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2]
4815*da0073e9SAndroid Build Coastguard Worker        # as a child of [1, 3]
4816*da0073e9SAndroid Build Coastguard Worker        res = [[], [], [], [], [4]]
4817*da0073e9SAndroid Build Coastguard Worker
4818*da0073e9SAndroid Build Coastguard Worker        def get_children_ids(event):
4819*da0073e9SAndroid Build Coastguard Worker            return [child.id for child in event.cpu_children]
4820*da0073e9SAndroid Build Coastguard Worker
4821*da0073e9SAndroid Build Coastguard Worker        assert [get_children_ids(event) for event in events] == res
4822*da0073e9SAndroid Build Coastguard Worker
4823*da0073e9SAndroid Build Coastguard Worker    def test_profiler_aggregation_table(self):
4824*da0073e9SAndroid Build Coastguard Worker        """
4825*da0073e9SAndroid Build Coastguard Worker        Test if the profiling result is aggregated for `str(prof)`
4826*da0073e9SAndroid Build Coastguard Worker
4827*da0073e9SAndroid Build Coastguard Worker        See: https://github.com/pytorch/pytorch/issues/37500
4828*da0073e9SAndroid Build Coastguard Worker        """
4829*da0073e9SAndroid Build Coastguard Worker
4830*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1024)
4831*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
4832*da0073e9SAndroid Build Coastguard Worker            torch.einsum("i->", x)
4833*da0073e9SAndroid Build Coastguard Worker
4834*da0073e9SAndroid Build Coastguard Worker        prof_str = str(prof)
4835*da0073e9SAndroid Build Coastguard Worker        prof_table = prof.table()
4836*da0073e9SAndroid Build Coastguard Worker
4837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(prof_table, prof_str)
4838*da0073e9SAndroid Build Coastguard Worker
4839*da0073e9SAndroid Build Coastguard Worker    def test_profiler_function_event_avg(self):
4840*da0073e9SAndroid Build Coastguard Worker        avg = FunctionEventAvg()
4841*da0073e9SAndroid Build Coastguard Worker        avg.add(
4842*da0073e9SAndroid Build Coastguard Worker            FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15)
4843*da0073e9SAndroid Build Coastguard Worker        )
4844*da0073e9SAndroid Build Coastguard Worker        avg.add(
4845*da0073e9SAndroid Build Coastguard Worker            FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30)
4846*da0073e9SAndroid Build Coastguard Worker        )
4847*da0073e9SAndroid Build Coastguard Worker        avg.add(avg)
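        # The two events above are 5 us and 10 us long; adding the running average to
        # itself doubles the aggregate to 4 events and 30 us in total, so the mean
        # cpu_time checked below is 30 / 4 = 7.5 us.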
4848*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.key, "foo")
4849*da0073e9SAndroid Build Coastguard Worker
4850*da0073e9SAndroid Build Coastguard Worker        # aggregate stats
4851*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.count, 4)
4852*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.cpu_time_total, 30)
4853*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.self_cpu_time_total, 30)
4854*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.device_time_total, 0)
4855*da0073e9SAndroid Build Coastguard Worker
4856*da0073e9SAndroid Build Coastguard Worker        # average stats
4857*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.cpu_time, 7.5)
4858*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(avg.device_time_total, 0)
4859*da0073e9SAndroid Build Coastguard Worker
4860*da0073e9SAndroid Build Coastguard Worker    def test_profiler_shapes(self):
4861*da0073e9SAndroid Build Coastguard Worker        print()
4862*da0073e9SAndroid Build Coastguard Worker        layer1 = torch.nn.Linear(20, 30)
4863*da0073e9SAndroid Build Coastguard Worker        layer2 = torch.nn.Linear(30, 40)
4864*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(128, 20)
4865*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
4866*da0073e9SAndroid Build Coastguard Worker            layer2(layer1(input))
4867*da0073e9SAndroid Build Coastguard Worker
4868*da0073e9SAndroid Build Coastguard Worker        print(prof.function_events)
4869*da0073e9SAndroid Build Coastguard Worker
4870*da0073e9SAndroid Build Coastguard Worker        linear_expected_shapes = [
4871*da0073e9SAndroid Build Coastguard Worker            [[128, 20], [30, 20], [30]],
4872*da0073e9SAndroid Build Coastguard Worker            [[128, 30], [40, 30], [40]],
4873*da0073e9SAndroid Build Coastguard Worker        ]
4874*da0073e9SAndroid Build Coastguard Worker
4875*da0073e9SAndroid Build Coastguard Worker        found_indices = set()
4876*da0073e9SAndroid Build Coastguard Worker        for event in prof.function_events:
4877*da0073e9SAndroid Build Coastguard Worker            if event.name == "aten::linear":
4878*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(event.input_shapes in linear_expected_shapes)
4879*da0073e9SAndroid Build Coastguard Worker                found_indices.add(linear_expected_shapes.index(event.input_shapes))
4880*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(found_indices), len(linear_expected_shapes))
4881*da0073e9SAndroid Build Coastguard Worker
4882*da0073e9SAndroid Build Coastguard Worker    def test_profiler_aggregation_lstm(self):
4883*da0073e9SAndroid Build Coastguard Worker        print()
4884*da0073e9SAndroid Build Coastguard Worker        rnn = torch.nn.LSTM(10, 20, 2)
4885*da0073e9SAndroid Build Coastguard Worker        total_time_s = 0
4886*da0073e9SAndroid Build Coastguard Worker        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
4887*da0073e9SAndroid Build Coastguard Worker            for i in range(20):
4888*da0073e9SAndroid Build Coastguard Worker                input = torch.randn(5, 3, 10)
4889*da0073e9SAndroid Build Coastguard Worker                h = torch.randn(2, 3, 20)
4890*da0073e9SAndroid Build Coastguard Worker                c = torch.randn(2, 3, 20)
4891*da0073e9SAndroid Build Coastguard Worker                start = time.time()
4892*da0073e9SAndroid Build Coastguard Worker                rnn(input, (h, c))
4893*da0073e9SAndroid Build Coastguard Worker                end = time.time()
4894*da0073e9SAndroid Build Coastguard Worker                total_time_s += end - start
4895*da0073e9SAndroid Build Coastguard Worker
4896*da0073e9SAndroid Build Coastguard Worker        print(prof.table(sort_by="self_cpu_time_total", row_limit=10, header="TEST"))
4897*da0073e9SAndroid Build Coastguard Worker        print(
4898*da0073e9SAndroid Build Coastguard Worker            prof.key_averages(group_by_input_shape=True).table(
4899*da0073e9SAndroid Build Coastguard Worker                sort_by="self_cpu_time_total", row_limit=10
4900*da0073e9SAndroid Build Coastguard Worker            )
4901*da0073e9SAndroid Build Coastguard Worker        )
4902*da0073e9SAndroid Build Coastguard Worker        print(
4903*da0073e9SAndroid Build Coastguard Worker            prof.table(
4904*da0073e9SAndroid Build Coastguard Worker                sort_by="self_cpu_time_total",
4905*da0073e9SAndroid Build Coastguard Worker                row_limit=10,
4906*da0073e9SAndroid Build Coastguard Worker                max_src_column_width=300,
4907*da0073e9SAndroid Build Coastguard Worker                header="TEST",
4908*da0073e9SAndroid Build Coastguard Worker                top_level_events_only=True,
4909*da0073e9SAndroid Build Coastguard Worker            )
4910*da0073e9SAndroid Build Coastguard Worker        )
4911*da0073e9SAndroid Build Coastguard Worker        print(
4912*da0073e9SAndroid Build Coastguard Worker            prof.key_averages(group_by_input_shape=True).table(
4913*da0073e9SAndroid Build Coastguard Worker                sort_by="self_cpu_time_total", row_limit=10, top_level_events_only=True
4914*da0073e9SAndroid Build Coastguard Worker            )
4915*da0073e9SAndroid Build Coastguard Worker        )
4916*da0073e9SAndroid Build Coastguard Worker
4917*da0073e9SAndroid Build Coastguard Worker        total_time_us = (
4918*da0073e9SAndroid Build Coastguard Worker            total_time_s * 1000.0 * 1000.0
4919*da0073e9SAndroid Build Coastguard Worker        )  # convert to microseconds, the profiler's default unit
4920*da0073e9SAndroid Build Coastguard Worker        print("Total time based on python measurements: ", _format_time(total_time_us))
4921*da0073e9SAndroid Build Coastguard Worker        print(
4922*da0073e9SAndroid Build Coastguard Worker            f"CPU time measurement python side overhead: {(total_time_us / prof.self_cpu_time_total - 1.0) * 100.0:.2f}%"
4923*da0073e9SAndroid Build Coastguard Worker        )
4924*da0073e9SAndroid Build Coastguard Worker
4925*da0073e9SAndroid Build Coastguard Worker        if sys.platform != "win32":
4926*da0073e9SAndroid Build Coastguard Worker            with tempfile.NamedTemporaryFile() as trace_file:
4927*da0073e9SAndroid Build Coastguard Worker                prof.export_chrome_trace(trace_file.name)
4928*da0073e9SAndroid Build Coastguard Worker
4929*da0073e9SAndroid Build Coastguard Worker    def test_record_function(self):
4930*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
4931*da0073e9SAndroid Build Coastguard Worker
4932*da0073e9SAndroid Build Coastguard Worker        def forward(x):
4933*da0073e9SAndroid Build Coastguard Worker            with record_function("outer"):
4934*da0073e9SAndroid Build Coastguard Worker                y = x * 2 + 4
4935*da0073e9SAndroid Build Coastguard Worker                with record_function("inner"):
4936*da0073e9SAndroid Build Coastguard Worker                    y = y - 1
4937*da0073e9SAndroid Build Coastguard Worker            y = y / 1
4938*da0073e9SAndroid Build Coastguard Worker
4939*da0073e9SAndroid Build Coastguard Worker        forward(x)
4940*da0073e9SAndroid Build Coastguard Worker
4941*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4942*da0073e9SAndroid Build Coastguard Worker            forward(x)
4943*da0073e9SAndroid Build Coastguard Worker
4944*da0073e9SAndroid Build Coastguard Worker        events = p.function_events
4945*da0073e9SAndroid Build Coastguard Worker        important_events = [
4946*da0073e9SAndroid Build Coastguard Worker            "outer",
4947*da0073e9SAndroid Build Coastguard Worker            "aten::mul",
4948*da0073e9SAndroid Build Coastguard Worker            "aten::add",
4949*da0073e9SAndroid Build Coastguard Worker            "inner",
4950*da0073e9SAndroid Build Coastguard Worker            "aten::sub",
4951*da0073e9SAndroid Build Coastguard Worker            "aten::div",
4952*da0073e9SAndroid Build Coastguard Worker        ]
4953*da0073e9SAndroid Build Coastguard Worker        idx = 0
4954*da0073e9SAndroid Build Coastguard Worker        for info in events:
4955*da0073e9SAndroid Build Coastguard Worker            if info.name == important_events[idx]:
4956*da0073e9SAndroid Build Coastguard Worker                idx = idx + 1
4957*da0073e9SAndroid Build Coastguard Worker            if idx == len(important_events):
4958*da0073e9SAndroid Build Coastguard Worker                break
4959*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(idx, len(important_events))
4960*da0073e9SAndroid Build Coastguard Worker
4961*da0073e9SAndroid Build Coastguard Worker        # We can also use record_function to decorate arbitrary functions
4962*da0073e9SAndroid Build Coastguard Worker        @record_function("my_func")
4963*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
4964*da0073e9SAndroid Build Coastguard Worker            return x + y
4965*da0073e9SAndroid Build Coastguard Worker
4966*da0073e9SAndroid Build Coastguard Worker        with profile(use_kineto=kineto_available()) as p:
4967*da0073e9SAndroid Build Coastguard Worker            f(1, 2)
4968*da0073e9SAndroid Build Coastguard Worker
4969*da0073e9SAndroid Build Coastguard Worker        self.assertTrue("my_func" in str(p))
4970*da0073e9SAndroid Build Coastguard Worker
4971*da0073e9SAndroid Build Coastguard Worker    def test_record_function_multithreaded(self):
4972*da0073e9SAndroid Build Coastguard Worker        rf = record_function("outer")
4973*da0073e9SAndroid Build Coastguard Worker        rf.__enter__()
4974*da0073e9SAndroid Build Coastguard Worker        with record_function("inner"):
4975*da0073e9SAndroid Build Coastguard Worker            # test that exiting the record function after starting another one
4976*da0073e9SAndroid Build Coastguard Worker            # doesn't throw.
4977*da0073e9SAndroid Build Coastguard Worker            rf.__exit__(None, None, None)
4978*da0073e9SAndroid Build Coastguard Worker
4979*da0073e9SAndroid Build Coastguard Worker        with record_function("inner"):
4980*da0073e9SAndroid Build Coastguard Worker            rf.__enter__()
4981*da0073e9SAndroid Build Coastguard Worker        # test that exiting the record function after ending another one
4982*da0073e9SAndroid Build Coastguard Worker        # doesn't throw.
4983*da0073e9SAndroid Build Coastguard Worker        rf.__exit__(None, None, None)
4984*da0073e9SAndroid Build Coastguard Worker
4985*da0073e9SAndroid Build Coastguard Worker    def test_dir(self):
4986*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, 10)
4987*da0073e9SAndroid Build Coastguard Worker        keys = dir(x)
4988*da0073e9SAndroid Build Coastguard Worker        self.assertIn("shape", keys)
4989*da0073e9SAndroid Build Coastguard Worker
4990*da0073e9SAndroid Build Coastguard Worker        # real and imag are only implemented for complex tensors.
4991*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(10, 10, dtype=torch.cfloat)
4992*da0073e9SAndroid Build Coastguard Worker        imag_key = "imag"
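        # hasattr() only swallows AttributeError; accessing .imag on a non-complex
        # tensor raises RuntimeError, so it propagates out of hasattr() here.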
4993*da0073e9SAndroid Build Coastguard Worker        self.assertRaises(RuntimeError, lambda: hasattr(x, imag_key))
4994*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(hasattr(y, imag_key))
4995*da0073e9SAndroid Build Coastguard Worker        keys.remove(imag_key)
4996*da0073e9SAndroid Build Coastguard Worker
4997*da0073e9SAndroid Build Coastguard Worker        for key in keys:
4998*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(hasattr(x, key))
4999*da0073e9SAndroid Build Coastguard Worker
5000*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_saved_output(self):
5001*da0073e9SAndroid Build Coastguard Worker        # Test an in-place operation on a view in which the in-place op saves
5002*da0073e9SAndroid Build Coastguard Worker        # its output. Previously, this created a reference cycle.
5003*da0073e9SAndroid Build Coastguard Worker        dealloc = [0]
5004*da0073e9SAndroid Build Coastguard Worker
5005*da0073e9SAndroid Build Coastguard Worker        class IncrementOnDelete:
5006*da0073e9SAndroid Build Coastguard Worker            def __del__(self):
5007*da0073e9SAndroid Build Coastguard Worker                dealloc[0] += 1
5008*da0073e9SAndroid Build Coastguard Worker
5009*da0073e9SAndroid Build Coastguard Worker        def test():
5010*da0073e9SAndroid Build Coastguard Worker            root = torch.randn(3, 3, requires_grad=True)
5011*da0073e9SAndroid Build Coastguard Worker            copy = root.clone()
5012*da0073e9SAndroid Build Coastguard Worker            copy.grad_fn.register_hook(IncrementOnDelete())
5013*da0073e9SAndroid Build Coastguard Worker            view = copy.view(9)
5014*da0073e9SAndroid Build Coastguard Worker            torch.nn.functional.relu(view, inplace=True)
5015*da0073e9SAndroid Build Coastguard Worker
5016*da0073e9SAndroid Build Coastguard Worker        test()
5017*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dealloc[0], 1)
5018*da0073e9SAndroid Build Coastguard Worker
5019*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_leaf_errors(self):
5020*da0073e9SAndroid Build Coastguard Worker        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
5021*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(1, requires_grad=True)
5022*da0073e9SAndroid Build Coastguard Worker        y = x.view_as(x)
5023*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5024*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5025*da0073e9SAndroid Build Coastguard Worker            "a view of a leaf Variable that "
5026*da0073e9SAndroid Build Coastguard Worker            "requires grad is being used in "
5027*da0073e9SAndroid Build Coastguard Worker            "an in-place operation.",
5028*da0073e9SAndroid Build Coastguard Worker        ):
5029*da0073e9SAndroid Build Coastguard Worker            y.add_(1)
5030*da0073e9SAndroid Build Coastguard Worker
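    # Editor's sketch, not collected by unittest: the test above exercises the error path
    # for in-place ops on a view of a leaf that requires grad; a common fix is to clone
    # the view first so the in-place op no longer touches the leaf's view chain. This is
    # an illustrative pattern, not part of the original suite.
    def _sketch_inplace_on_view_leaf_fix(self):
        x = torch.zeros(1, requires_grad=True)
        y = x.view_as(x).clone()  # clone gives y its own storage; it is no longer a view of x
        y.add_(1)  # fine now: y is not a view of a leaf that requires grad
        y.sum().backward()
        self.assertEqual(x.grad, torch.ones(1))
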
5031*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_backward(self):
5032*da0073e9SAndroid Build Coastguard Worker        # Issue #10532: Make sure that this does not raise RuntimeError.
5033*da0073e9SAndroid Build Coastguard Worker        net = nn.Sequential(nn.InstanceNorm2d(2), nn.ReLU(True))
5034*da0073e9SAndroid Build Coastguard Worker
5035*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[[[1.0, 1.0]]]], requires_grad=True)
5036*da0073e9SAndroid Build Coastguard Worker        (g,) = torch.autograd.grad(
5037*da0073e9SAndroid Build Coastguard Worker            net(x).pow(2), [x], grad_outputs=x.new_ones(x.shape), create_graph=True
5038*da0073e9SAndroid Build Coastguard Worker        )
5039*da0073e9SAndroid Build Coastguard Worker        torch.autograd.grad(g.sum(), [x])
5040*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, torch.tensor([[[[1.0, 1.0]]]]))
5041*da0073e9SAndroid Build Coastguard Worker
5042*da0073e9SAndroid Build Coastguard Worker        # https://discuss.pytorch.org/t/freeing-buffer-strange-behavior/31955/8
5043*da0073e9SAndroid Build Coastguard Worker        inputs = torch.ones((1, 3, 256, 256), requires_grad=True)
5044*da0073e9SAndroid Build Coastguard Worker
5045*da0073e9SAndroid Build Coastguard Worker        tmp1 = (inputs + 1).view_as(inputs)
5046*da0073e9SAndroid Build Coastguard Worker        tmp2 = torch.nn.functional.threshold(tmp1, 0.0, 0.0, True)
5047*da0073e9SAndroid Build Coastguard Worker        prob_interpolated = torch.sigmoid(tmp2)
5048*da0073e9SAndroid Build Coastguard Worker
5049*da0073e9SAndroid Build Coastguard Worker        gradients = torch.autograd.grad(
5050*da0073e9SAndroid Build Coastguard Worker            outputs=prob_interpolated,
5051*da0073e9SAndroid Build Coastguard Worker            inputs=inputs,
5052*da0073e9SAndroid Build Coastguard Worker            grad_outputs=torch.ones(prob_interpolated.size()),
5053*da0073e9SAndroid Build Coastguard Worker            create_graph=True,
5054*da0073e9SAndroid Build Coastguard Worker            retain_graph=True,
5055*da0073e9SAndroid Build Coastguard Worker        )[0]
5056*da0073e9SAndroid Build Coastguard Worker
5057*da0073e9SAndroid Build Coastguard Worker        gradient_penalty = gradients.sum()
5058*da0073e9SAndroid Build Coastguard Worker        gradient_penalty.backward()
5059*da0073e9SAndroid Build Coastguard Worker
5060*da0073e9SAndroid Build Coastguard Worker        fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
5061*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(fn.name(), "ThresholdBackwardBackward0")
5062*da0073e9SAndroid Build Coastguard Worker
5063*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_weak_grad_fn(self):
5064*da0073e9SAndroid Build Coastguard Worker        # Issue 23502: Test that b's grad_fn is preserved.
5065*da0073e9SAndroid Build Coastguard Worker        a = torch.arange(10.0, requires_grad=True)
5066*da0073e9SAndroid Build Coastguard Worker
5067*da0073e9SAndroid Build Coastguard Worker        b = a.narrow(0, 0, 2).clone().view(-1)
5068*da0073e9SAndroid Build Coastguard Worker        b.relu_()
5069*da0073e9SAndroid Build Coastguard Worker
5070*da0073e9SAndroid Build Coastguard Worker        c = b.clone()
5071*da0073e9SAndroid Build Coastguard Worker        del b
5072*da0073e9SAndroid Build Coastguard Worker        gc.collect()
5073*da0073e9SAndroid Build Coastguard Worker
5074*da0073e9SAndroid Build Coastguard Worker        s = c.sum()
5075*da0073e9SAndroid Build Coastguard Worker        s.backward()
5076*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(s, torch.tensor(1.0))
5077*da0073e9SAndroid Build Coastguard Worker
5078*da0073e9SAndroid Build Coastguard Worker        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
5079*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(10, requires_grad=True).narrow(0, 0, 10)
5080*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
5081*da0073e9SAndroid Build Coastguard Worker            b = a.relu_()
5082*da0073e9SAndroid Build Coastguard Worker
5083*da0073e9SAndroid Build Coastguard Worker    def test_out_variant_raises_when_inputs_require_grad(self):
5084*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, requires_grad=True)
5085*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, requires_grad=True)
5086*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros_like(a)
5087*da0073e9SAndroid Build Coastguard Worker
5088*da0073e9SAndroid Build Coastguard Worker        # out=... functions don't support automatic differentiation currently
5089*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x))
5090*da0073e9SAndroid Build Coastguard Worker
5091*da0073e9SAndroid Build Coastguard Worker        # the inputs can require grad if we're in no_grad() mode
5092*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
5093*da0073e9SAndroid Build Coastguard Worker            torch.mul(a, b, out=x)
5094*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, a * b)
5095*da0073e9SAndroid Build Coastguard Worker
5096*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2)
5097*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2)
5098*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(2, 2, requires_grad=True)
5099*da0073e9SAndroid Build Coastguard Worker        # we should throw an exception if the output requires grad
5100*da0073e9SAndroid Build Coastguard Worker        self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x))
5101*da0073e9SAndroid Build Coastguard Worker
5102*da0073e9SAndroid Build Coastguard Worker    def test_anomaly_detect_nan(self):
5103*da0073e9SAndroid Build Coastguard Worker        size = 10
5104*da0073e9SAndroid Build Coastguard Worker
5105*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
5106*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5107*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1, inp2, fail_0th):
5108*da0073e9SAndroid Build Coastguard Worker                ctx.fail_0th = fail_0th
5109*da0073e9SAndroid Build Coastguard Worker                return inp1.sum(0, keepdim=True)
5110*da0073e9SAndroid Build Coastguard Worker
5111*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5112*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
5113*da0073e9SAndroid Build Coastguard Worker                gI = gO.clone().expand(size)
5114*da0073e9SAndroid Build Coastguard Worker                gI[0] = 0
5115*da0073e9SAndroid Build Coastguard Worker                gI[0] /= 0  # Generate a nan
5116*da0073e9SAndroid Build Coastguard Worker                if ctx.fail_0th:
5117*da0073e9SAndroid Build Coastguard Worker                    return gI, None, None
5118*da0073e9SAndroid Build Coastguard Worker                else:
5119*da0073e9SAndroid Build Coastguard Worker                    return None, gI, None
5120*da0073e9SAndroid Build Coastguard Worker
5121*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(size, requires_grad=True)
5122*da0073e9SAndroid Build Coastguard Worker        out = MyFunc.apply(inp, inp, True)
5123*da0073e9SAndroid Build Coastguard Worker        out.backward()  # Should not fail
5124*da0073e9SAndroid Build Coastguard Worker
5125*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(size, requires_grad=True)
5126*da0073e9SAndroid Build Coastguard Worker        out = MyFunc.apply(inp, inp, True)
5127*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5128*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5129*da0073e9SAndroid Build Coastguard Worker            "Function 'MyFuncBackward' returned nan values in its 0th output.",
5130*da0073e9SAndroid Build Coastguard Worker        ):
5131*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
5132*da0073e9SAndroid Build Coastguard Worker                with detect_anomaly():
5133*da0073e9SAndroid Build Coastguard Worker                    out.backward()
5134*da0073e9SAndroid Build Coastguard Worker            self.assertIn("No forward pass information", str(w[0].message))
5135*da0073e9SAndroid Build Coastguard Worker
5136*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(size, requires_grad=True)
5137*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5138*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5139*da0073e9SAndroid Build Coastguard Worker            "Function 'MyFuncBackward' returned nan values in its 1th output.",
5140*da0073e9SAndroid Build Coastguard Worker        ):
5141*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
5142*da0073e9SAndroid Build Coastguard Worker                with detect_anomaly():
5143*da0073e9SAndroid Build Coastguard Worker                    out = MyFunc.apply(inp, inp, False)
5144*da0073e9SAndroid Build Coastguard Worker                    out.backward()
5145*da0073e9SAndroid Build Coastguard Worker            self.assertIn("MyFunc.apply", str(w[0].message))
5146*da0073e9SAndroid Build Coastguard Worker
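    # Editor's sketch, not collected by unittest: a minimal end-user reproduction of what
    # the anomaly test above exercises via a custom Function. At x == 0 the gradient of
    # 0 * sqrt(x) is 0 * inf = nan, so under detect_anomaly() the backward pass raises a
    # RuntimeError naming the offending backward node (the exact node name is not
    # asserted here).
    def _sketch_detect_anomaly_usage(self):
        x = torch.zeros(1, requires_grad=True)
        out = (0 * torch.sqrt(x)).sum()
        with self.assertRaisesRegex(RuntimeError, "returned nan values"):
            with detect_anomaly():
                out.backward()
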
5147*da0073e9SAndroid Build Coastguard Worker    def test_calculate_shape_util(self):
5148*da0073e9SAndroid Build Coastguard Worker        out = torch.randn(10, 5, requires_grad=True)
5149*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(5, 10, requires_grad=True)
5150*da0073e9SAndroid Build Coastguard Worker        out_shape, grad_shape = _calculate_shape(out, grad, False)
5151*da0073e9SAndroid Build Coastguard Worker
5152*da0073e9SAndroid Build Coastguard Worker        assert out_shape == torch.Size([10, 5])
5153*da0073e9SAndroid Build Coastguard Worker        assert grad_shape == torch.Size([5, 10])
5154*da0073e9SAndroid Build Coastguard Worker
5155*da0073e9SAndroid Build Coastguard Worker        out = torch.nested.as_nested_tensor(
5156*da0073e9SAndroid Build Coastguard Worker            [
5157*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, 5, requires_grad=True),
5158*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, 5, requires_grad=True),
5159*da0073e9SAndroid Build Coastguard Worker                torch.randn(10, 5, requires_grad=True),
5160*da0073e9SAndroid Build Coastguard Worker            ]
5161*da0073e9SAndroid Build Coastguard Worker        )
5162*da0073e9SAndroid Build Coastguard Worker        grad = torch.nested.as_nested_tensor(
5163*da0073e9SAndroid Build Coastguard Worker            [
5164*da0073e9SAndroid Build Coastguard Worker                torch.randn(5, 10, requires_grad=True),
5165*da0073e9SAndroid Build Coastguard Worker                torch.randn(5, 10, requires_grad=True),
5166*da0073e9SAndroid Build Coastguard Worker            ]
5167*da0073e9SAndroid Build Coastguard Worker        )
5168*da0073e9SAndroid Build Coastguard Worker        out_shape, grad_shape = _calculate_shape(out, grad, False)
5169*da0073e9SAndroid Build Coastguard Worker
5170*da0073e9SAndroid Build Coastguard Worker        assert torch.equal(out_shape, torch.tensor([[10, 5], [10, 5], [10, 5]]))
5171*da0073e9SAndroid Build Coastguard Worker        assert torch.equal(grad_shape, torch.tensor([[5, 10], [5, 10]]))
5172*da0073e9SAndroid Build Coastguard Worker
5173*da0073e9SAndroid Build Coastguard Worker    def test_nested_anomaly_detect_nan(self):
5174*da0073e9SAndroid Build Coastguard Worker        size = 10
5175*da0073e9SAndroid Build Coastguard Worker
5176*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
5177*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5178*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1, fail_0th):
5179*da0073e9SAndroid Build Coastguard Worker                ctx.fail_0th = fail_0th
5180*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(inp1)
5181*da0073e9SAndroid Build Coastguard Worker                return inp1.sum(0, keepdim=True)
5182*da0073e9SAndroid Build Coastguard Worker
5183*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5184*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
5185*da0073e9SAndroid Build Coastguard Worker                (inp,) = ctx.saved_tensors
5186*da0073e9SAndroid Build Coastguard Worker                fail_0th = ctx.fail_0th
5187*da0073e9SAndroid Build Coastguard Worker                g = gO.clone().expand(size)
5188*da0073e9SAndroid Build Coastguard Worker                gI = MyFunc2.apply(g * inp, g + inp, fail_0th)
5189*da0073e9SAndroid Build Coastguard Worker                return gI, None
5190*da0073e9SAndroid Build Coastguard Worker
5191*da0073e9SAndroid Build Coastguard Worker        class MyFunc2(Function):
5192*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5193*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1, inp2, fail_0th):
5194*da0073e9SAndroid Build Coastguard Worker                ctx.fail_0th = fail_0th
5195*da0073e9SAndroid Build Coastguard Worker                return inp1 * 2.0 + inp2
5196*da0073e9SAndroid Build Coastguard Worker
5197*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5198*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
5199*da0073e9SAndroid Build Coastguard Worker                fail_0th = ctx.fail_0th
5200*da0073e9SAndroid Build Coastguard Worker                g1 = gO.clone()
5201*da0073e9SAndroid Build Coastguard Worker                g2 = gO.clone()
5202*da0073e9SAndroid Build Coastguard Worker                g1[0] = 0
5203*da0073e9SAndroid Build Coastguard Worker                g2[0] = 0
5204*da0073e9SAndroid Build Coastguard Worker                # generate a nan
5205*da0073e9SAndroid Build Coastguard Worker                if fail_0th:
5206*da0073e9SAndroid Build Coastguard Worker                    g1[0] /= 0
5207*da0073e9SAndroid Build Coastguard Worker                else:
5208*da0073e9SAndroid Build Coastguard Worker                    g2[0] /= 0
5209*da0073e9SAndroid Build Coastguard Worker                return g1, g2, None
5210*da0073e9SAndroid Build Coastguard Worker
5211*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(size, requires_grad=True)
5212*da0073e9SAndroid Build Coastguard Worker        out = MyFunc.apply(inp, True)
5213*da0073e9SAndroid Build Coastguard Worker        (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5214*da0073e9SAndroid Build Coastguard Worker        gsum = ginp.sum()
5215*da0073e9SAndroid Build Coastguard Worker        gsum.backward()  # should not fail
5216*da0073e9SAndroid Build Coastguard Worker
5217*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(size, requires_grad=True)
5218*da0073e9SAndroid Build Coastguard Worker        out = MyFunc.apply(inp, True)
5219*da0073e9SAndroid Build Coastguard Worker        (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5220*da0073e9SAndroid Build Coastguard Worker        gsum = ginp.sum()
5221*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5222*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5223*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
5224*da0073e9SAndroid Build Coastguard Worker                "Function 'MyFunc2Backward' returned nan values in its 0th output.",
5225*da0073e9SAndroid Build Coastguard Worker            ):
5226*da0073e9SAndroid Build Coastguard Worker                with detect_anomaly():
5227*da0073e9SAndroid Build Coastguard Worker                    gsum.backward()
5228*da0073e9SAndroid Build Coastguard Worker        self.assertIn("No forward pass information", str(w[1].message))
5229*da0073e9SAndroid Build Coastguard Worker
5230*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(size, requires_grad=True)
5231*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5232*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5233*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
5234*da0073e9SAndroid Build Coastguard Worker                "Function 'MyFunc2Backward' returned nan values in its 1th output.",
5235*da0073e9SAndroid Build Coastguard Worker            ):
5236*da0073e9SAndroid Build Coastguard Worker                with detect_anomaly():
5237*da0073e9SAndroid Build Coastguard Worker                    out = MyFunc.apply(inp, False)
5238*da0073e9SAndroid Build Coastguard Worker                    (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5239*da0073e9SAndroid Build Coastguard Worker                    gsum = ginp.sum()
5240*da0073e9SAndroid Build Coastguard Worker                    gsum.backward()
5241*da0073e9SAndroid Build Coastguard Worker        self.assertIn("MyFunc2.apply", str(w[1].message))
5242*da0073e9SAndroid Build Coastguard Worker        self.assertIn("MyFunc.apply", str(w[2].message))
5243*da0073e9SAndroid Build Coastguard Worker
5244*da0073e9SAndroid Build Coastguard Worker    def test_anomaly_grad_warnings(self):
5245*da0073e9SAndroid Build Coastguard Worker        # PyTorch won't throw warnings if there is an error
5246*da0073e9SAndroid Build Coastguard Worker        # but we'd want to at least see them in stderr
5247*da0073e9SAndroid Build Coastguard Worker
5248*da0073e9SAndroid Build Coastguard Worker        class StdErrDiverter:
5249*da0073e9SAndroid Build Coastguard Worker            def __enter__(self):
5250*da0073e9SAndroid Build Coastguard Worker                self.stderr_orig = sys.stderr
5251*da0073e9SAndroid Build Coastguard Worker                self.stderr_new = io.StringIO()
5252*da0073e9SAndroid Build Coastguard Worker                sys.stderr = self.stderr_new
5253*da0073e9SAndroid Build Coastguard Worker                return self
5254*da0073e9SAndroid Build Coastguard Worker
5255*da0073e9SAndroid Build Coastguard Worker            def __exit__(self, *args):
5256*da0073e9SAndroid Build Coastguard Worker                self.captured = self.stderr_new.getvalue()
5257*da0073e9SAndroid Build Coastguard Worker                sys.stderr = self.stderr_orig
5258*da0073e9SAndroid Build Coastguard Worker
5259*da0073e9SAndroid Build Coastguard Worker        # if the warnings don't throw, they will be handled as regular warnings
5260*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5261*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5262*da0073e9SAndroid Build Coastguard Worker            "one of the variables needed for gradient computation has been "
5263*da0073e9SAndroid Build Coastguard Worker            "modified by an inplace operation",
5264*da0073e9SAndroid Build Coastguard Worker        ):
5265*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
5266*da0073e9SAndroid Build Coastguard Worker                with detect_anomaly():
5267*da0073e9SAndroid Build Coastguard Worker                    a = torch.randn(5, requires_grad=True)
5268*da0073e9SAndroid Build Coastguard Worker                    d1 = a + 1
5269*da0073e9SAndroid Build Coastguard Worker                    d2 = d1**2
5270*da0073e9SAndroid Build Coastguard Worker                    d1 += 1
5271*da0073e9SAndroid Build Coastguard Worker                    torch.autograd.grad(d2.sum(), a)
5272*da0073e9SAndroid Build Coastguard Worker
5273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(w), 2)
5274*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Anomaly Detection has been enabled", str(w[0].message))
5275*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Error detected in PowBackward0", str(w[1].message))
5276*da0073e9SAndroid Build Coastguard Worker
5277*da0073e9SAndroid Build Coastguard Worker        # if the warning throws, it will be printed to sys.stderr
5278*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
5279*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
5280*da0073e9SAndroid Build Coastguard Worker            "one of the variables needed for gradient computation has been "
5281*da0073e9SAndroid Build Coastguard Worker            "modified by an inplace operation",
5282*da0073e9SAndroid Build Coastguard Worker        ):
5283*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
5284*da0073e9SAndroid Build Coastguard Worker                with detect_anomaly():
5285*da0073e9SAndroid Build Coastguard Worker                    warnings.simplefilter("error")
5286*da0073e9SAndroid Build Coastguard Worker                    with StdErrDiverter() as s:
5287*da0073e9SAndroid Build Coastguard Worker                        a = torch.randn(5, requires_grad=True)
5288*da0073e9SAndroid Build Coastguard Worker                        d1 = a + 1
5289*da0073e9SAndroid Build Coastguard Worker                        d2 = d1**2
5290*da0073e9SAndroid Build Coastguard Worker                        d1 += 1
5291*da0073e9SAndroid Build Coastguard Worker                        torch.autograd.grad(d2.sum(), a)
5292*da0073e9SAndroid Build Coastguard Worker
5293*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(w), 1)
5294*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Anomaly Detection has been enabled", str(w[0].message))
5295*da0073e9SAndroid Build Coastguard Worker        self.assertIn("Error detected in PowBackward0", s.captured)
5296*da0073e9SAndroid Build Coastguard Worker
5297*da0073e9SAndroid Build Coastguard Worker    def test_anomaly_assign_parent_cleanup(self):
5298*da0073e9SAndroid Build Coastguard Worker        # Test that python objects created are properly cleaned up when assign_parent is called
5299*da0073e9SAndroid Build Coastguard Worker
5300*da0073e9SAndroid Build Coastguard Worker        def get_ref():
5301*da0073e9SAndroid Build Coastguard Worker            # we use torch.exp here but any function that will construct a new node in its
5302*da0073e9SAndroid Build Coastguard Worker            # backward call in grad mode will work
5303*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, 2, requires_grad=True)
5304*da0073e9SAndroid Build Coastguard Worker            t = x.exp()
5305*da0073e9SAndroid Build Coastguard Worker
5306*da0073e9SAndroid Build Coastguard Worker            # ExpBackward calls mul, creating the MulBackward node when create_graph=True.
5307*da0073e9SAndroid Build Coastguard Worker            # In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
5308*da0073e9SAndroid Build Coastguard Worker            # MulBackward's anomaly metadata dict, creating the following reference chain:
5309*da0073e9SAndroid Build Coastguard Worker            #
5310*da0073e9SAndroid Build Coastguard Worker            # grad -> MulBackward -> PyObject -> ExpBackward
5311*da0073e9SAndroid Build Coastguard Worker            #
5312*da0073e9SAndroid Build Coastguard Worker            with detect_anomaly():
5313*da0073e9SAndroid Build Coastguard Worker                grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)
5314*da0073e9SAndroid Build Coastguard Worker
5315*da0073e9SAndroid Build Coastguard Worker            # We insert a new Foo object into ExpBackward's metadata dict and keep a weak reference to it:
5316*da0073e9SAndroid Build Coastguard Worker            #
5317*da0073e9SAndroid Build Coastguard Worker            # (PyObject) -> ExpBackward -> dict -> *Foo*
5318*da0073e9SAndroid Build Coastguard Worker            #            t ----^        WeakRef ---^
5319*da0073e9SAndroid Build Coastguard Worker            #
5320*da0073e9SAndroid Build Coastguard Worker            # We want to test that, when grad goes out of scope at the end of this function, the PyObject is destroyed.
5321*da0073e9SAndroid Build Coastguard Worker            # We can check this by verifying that Foo is not kept alive once t is destroyed.
5322*da0073e9SAndroid Build Coastguard Worker            class Foo:
5323*da0073e9SAndroid Build Coastguard Worker                pass
5324*da0073e9SAndroid Build Coastguard Worker
5325*da0073e9SAndroid Build Coastguard Worker            my_obj = Foo()
5326*da0073e9SAndroid Build Coastguard Worker            meta_dict = t.grad_fn.metadata
5327*da0073e9SAndroid Build Coastguard Worker            meta_dict[0] = my_obj
5328*da0073e9SAndroid Build Coastguard Worker            ref = weakref.ref(my_obj)
5329*da0073e9SAndroid Build Coastguard Worker            return t, ref
5330*da0073e9SAndroid Build Coastguard Worker
5331*da0073e9SAndroid Build Coastguard Worker        t, ref = get_ref()
5332*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(ref())
5333*da0073e9SAndroid Build Coastguard Worker        del t
5334*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(ref())
5335*da0073e9SAndroid Build Coastguard Worker
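    # Editor's sketch, not collected by unittest: the cleanup tests around here rely on a
    # plain weakref pattern for detecting leaked Python objects -- stash an object inside
    # the structure whose lifetime you care about, keep only a weak reference to it, and
    # check that the weak reference dies with its owner. Payload and holder below are
    # hypothetical stand-ins for the grad_fn metadata dict used above.
    def _sketch_weakref_leak_check(self):
        class Payload:
            pass

        holder = {0: Payload()}  # holder is the only strong reference to the Payload
        ref = weakref.ref(holder[0])
        self.assertIsNotNone(ref())
        del holder  # dropping the owner must also free the Payload
        gc.collect()
        self.assertIsNone(ref())
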
5336*da0073e9SAndroid Build Coastguard Worker    def test_nested_anomaly_printstack_cleanup(self):
5337*da0073e9SAndroid Build Coastguard Worker        # Test if metadata dict PyObject is properly destroyed
5338*da0073e9SAndroid Build Coastguard Worker        def get_ref():
5339*da0073e9SAndroid Build Coastguard Worker            # This is similar to the construction in test_anomaly_assign_parent_cleanup:
5340*da0073e9SAndroid Build Coastguard Worker            #
5341*da0073e9SAndroid Build Coastguard Worker            # MyFuncBackward2 -> PyObject -> MyFuncBackward -> dict -> Foo
5342*da0073e9SAndroid Build Coastguard Worker            #                               out ---^         WeakRef ---^
5343*da0073e9SAndroid Build Coastguard Worker            #
5344*da0073e9SAndroid Build Coastguard Worker            # We want to check that Foo is still properly destroyed even when MyFunc2Backward's
5345*da0073e9SAndroid Build Coastguard Worker            # AnomalyMetadata calls printstack, which does some python object manipulation.
5346*da0073e9SAndroid Build Coastguard Worker            #
5347*da0073e9SAndroid Build Coastguard Worker            # You might be wondering why we still need test_anomaly_assign_parent_cleanup,
5348*da0073e9SAndroid Build Coastguard Worker            # since if the PyObject were not destroyed here, wouldn't this test detect that as well?
5349*da0073e9SAndroid Build Coastguard Worker            # The answer is that a custom function's PyObject (THPFunction) actually only holds
5350*da0073e9SAndroid Build Coastguard Worker            # a weak reference to the C++ node!
5351*da0073e9SAndroid Build Coastguard Worker            class MyFunc(Function):
5352*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5353*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x):
5354*da0073e9SAndroid Build Coastguard Worker                    ctx.save_for_backward(x)
5355*da0073e9SAndroid Build Coastguard Worker                    return x
5356*da0073e9SAndroid Build Coastguard Worker
5357*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5358*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, gO):
5359*da0073e9SAndroid Build Coastguard Worker                    (x,) = ctx.saved_tensors
5360*da0073e9SAndroid Build Coastguard Worker                    return MyFunc2.apply(x)
5361*da0073e9SAndroid Build Coastguard Worker
5362*da0073e9SAndroid Build Coastguard Worker            class MyFunc2(Function):
5363*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5364*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x):
5365*da0073e9SAndroid Build Coastguard Worker                    return x
5366*da0073e9SAndroid Build Coastguard Worker
5367*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5368*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, gO):
5369*da0073e9SAndroid Build Coastguard Worker                    return gO + float("NaN")
5370*da0073e9SAndroid Build Coastguard Worker
5371*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand(1, requires_grad=True)
5372*da0073e9SAndroid Build Coastguard Worker            out = MyFunc.apply(inp)
5373*da0073e9SAndroid Build Coastguard Worker            (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
5374*da0073e9SAndroid Build Coastguard Worker
5375*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
5376*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
5377*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
5378*da0073e9SAndroid Build Coastguard Worker                    "Function 'MyFunc2Backward' returned nan values in its 0th output.",
5379*da0073e9SAndroid Build Coastguard Worker                ):
5380*da0073e9SAndroid Build Coastguard Worker                    with detect_anomaly():
5381*da0073e9SAndroid Build Coastguard Worker                        ginp.backward()
5382*da0073e9SAndroid Build Coastguard Worker
5383*da0073e9SAndroid Build Coastguard Worker            class Foo:
5384*da0073e9SAndroid Build Coastguard Worker                pass
5385*da0073e9SAndroid Build Coastguard Worker
5386*da0073e9SAndroid Build Coastguard Worker            my_obj = Foo()
5387*da0073e9SAndroid Build Coastguard Worker            meta_dict = out.grad_fn.metadata
5388*da0073e9SAndroid Build Coastguard Worker            meta_dict[0] = my_obj
5389*da0073e9SAndroid Build Coastguard Worker            ref = weakref.ref(my_obj)
5390*da0073e9SAndroid Build Coastguard Worker            return out, ref
5391*da0073e9SAndroid Build Coastguard Worker
5392*da0073e9SAndroid Build Coastguard Worker        t, ref = get_ref()
5393*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(ref())
5394*da0073e9SAndroid Build Coastguard Worker        del t
5395*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(ref())
5396*da0073e9SAndroid Build Coastguard Worker
5397*da0073e9SAndroid Build Coastguard Worker    def test_anomaly_mode_no_check_nan(self):
5398*da0073e9SAndroid Build Coastguard Worker        class MyFunc(torch.autograd.Function):
5399*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5400*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp):
5401*da0073e9SAndroid Build Coastguard Worker                return inp.clone()
5402*da0073e9SAndroid Build Coastguard Worker
5403*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5404*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
5405*da0073e9SAndroid Build Coastguard Worker                return torch.tensor(float("nan")).expand(10, 10)
5406*da0073e9SAndroid Build Coastguard Worker
5407*da0073e9SAndroid Build Coastguard Worker        def run_fn(a):
5408*da0073e9SAndroid Build Coastguard Worker            out = MyFunc.apply(a)
5409*da0073e9SAndroid Build Coastguard Worker            return out.sum()
5410*da0073e9SAndroid Build Coastguard Worker
5411*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
5412*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.detect_anomaly(check_nan=False):
5413*da0073e9SAndroid Build Coastguard Worker                inp = torch.rand(10, 10, requires_grad=True)
5414*da0073e9SAndroid Build Coastguard Worker                out = run_fn(inp)
5415*da0073e9SAndroid Build Coastguard Worker                out.backward(retain_graph=True)
5416*da0073e9SAndroid Build Coastguard Worker
5417*da0073e9SAndroid Build Coastguard Worker                with torch.autograd.detect_anomaly(check_nan=True):
5418*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(
5419*da0073e9SAndroid Build Coastguard Worker                        RuntimeError,
5420*da0073e9SAndroid Build Coastguard Worker                        "Function 'MyFuncBackward' returned nan values in its 0th output.",
5421*da0073e9SAndroid Build Coastguard Worker                    ):
5422*da0073e9SAndroid Build Coastguard Worker                        out.backward(retain_graph=True)
5423*da0073e9SAndroid Build Coastguard Worker
5424*da0073e9SAndroid Build Coastguard Worker                out.backward()
5425*da0073e9SAndroid Build Coastguard Worker
5426*da0073e9SAndroid Build Coastguard Worker    def test_no_grad_copy(self):
5427*da0073e9SAndroid Build Coastguard Worker        # create autograd function that saves grad pointer as class static
5428*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
5429*da0073e9SAndroid Build Coastguard Worker            static_grad_ptr = None
5430*da0073e9SAndroid Build Coastguard Worker
5431*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5432*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1, inp2):
5433*da0073e9SAndroid Build Coastguard Worker                return inp1 + inp2
5434*da0073e9SAndroid Build Coastguard Worker
5435*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5436*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
5437*da0073e9SAndroid Build Coastguard Worker                MyFunc.static_grad_ptr = grad.data_ptr()
5438*da0073e9SAndroid Build Coastguard Worker                return grad, grad
5439*da0073e9SAndroid Build Coastguard Worker
5440*da0073e9SAndroid Build Coastguard Worker        class NonContGradFunc(Function):
5441*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5442*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1):
5443*da0073e9SAndroid Build Coastguard Worker                ctx.size = inp1.size()
5444*da0073e9SAndroid Build Coastguard Worker                return torch.tensor([1.0])
5445*da0073e9SAndroid Build Coastguard Worker
5446*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5447*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
5448*da0073e9SAndroid Build Coastguard Worker                return torch.ones(1).expand(ctx.size)
5449*da0073e9SAndroid Build Coastguard Worker
5450*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, 6, requires_grad=True)
5451*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(5, 6, requires_grad=True)
5452*da0073e9SAndroid Build Coastguard Worker        # non-contiguous grad should be copied
5453*da0073e9SAndroid Build Coastguard Worker        NonContGradFunc.apply(MyFunc.apply(a, b)).backward()
5454*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(a.grad.data_ptr() == MyFunc.static_grad_ptr)
5455*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(b.grad.data_ptr() == MyFunc.static_grad_ptr)
5456*da0073e9SAndroid Build Coastguard Worker        # test case that should trigger no copy for one of a,b
5457*da0073e9SAndroid Build Coastguard Worker        a.grad = b.grad = None
5458*da0073e9SAndroid Build Coastguard Worker        MyFunc.apply(a, b)[1][0].backward()
5459*da0073e9SAndroid Build Coastguard Worker        p_g = MyFunc.static_grad_ptr
5460*da0073e9SAndroid Build Coastguard Worker        p_a = a.grad.data_ptr()
5461*da0073e9SAndroid Build Coastguard Worker        p_b = b.grad.data_ptr()
5462*da0073e9SAndroid Build Coastguard Worker        # check that a and b use different grad buffers
5463*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(p_a == p_b)
5464*da0073e9SAndroid Build Coastguard Worker        # check one of them is using the computed buffer
5465*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(p_a == p_g or p_b == p_g)
5466*da0073e9SAndroid Build Coastguard Worker
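    # Editor's sketch, not collected by unittest: the "no grad copy" tests compare
    # data_ptr() values to tell whether autograd reused the incoming gradient buffer for
    # .grad or had to copy it. This shows the bare aliasing check on ordinary tensors.
    def _sketch_data_ptr_aliasing(self):
        g = torch.randn(3)
        alias = g  # same storage, same data pointer
        copy = g.clone()  # new storage, different data pointer
        self.assertEqual(alias.data_ptr(), g.data_ptr())
        self.assertNotEqual(copy.data_ptr(), g.data_ptr())
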
5467*da0073e9SAndroid Build Coastguard Worker    def test_no_grad_copy_sparse(self):
5468*da0073e9SAndroid Build Coastguard Worker        # create autograd function that saves grad pointer as class static
5469*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
5470*da0073e9SAndroid Build Coastguard Worker            static_grad_ptr = None
5471*da0073e9SAndroid Build Coastguard Worker
5472*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5473*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1, inp2):
5474*da0073e9SAndroid Build Coastguard Worker                return inp1 + inp2
5475*da0073e9SAndroid Build Coastguard Worker
5476*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5477*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
5478*da0073e9SAndroid Build Coastguard Worker                MyFunc.static_grad_ptr = grad._values().data_ptr()
5479*da0073e9SAndroid Build Coastguard Worker                return grad, grad
5480*da0073e9SAndroid Build Coastguard Worker
5481*da0073e9SAndroid Build Coastguard Worker        class NonContGradFunc(Function):
5482*da0073e9SAndroid Build Coastguard Worker            static_grad_ptr = None
5483*da0073e9SAndroid Build Coastguard Worker
5484*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5485*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp1, inp2):
5486*da0073e9SAndroid Build Coastguard Worker                return inp1 + inp2
5487*da0073e9SAndroid Build Coastguard Worker
5488*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5489*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
5490*da0073e9SAndroid Build Coastguard Worker                # Create a sparse tensor with non-contiguous indices and values
5491*da0073e9SAndroid Build Coastguard Worker                # and return it as grad.
5492*da0073e9SAndroid Build Coastguard Worker                v = torch.rand(1, 3)
5493*da0073e9SAndroid Build Coastguard Worker                i = torch.ones(1, 1, dtype=torch.long)
5494*da0073e9SAndroid Build Coastguard Worker                nv = v.expand(8, 3)
5495*da0073e9SAndroid Build Coastguard Worker                ni = i.expand(1, 8)
5496*da0073e9SAndroid Build Coastguard Worker                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
5497*da0073e9SAndroid Build Coastguard Worker                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
5498*da0073e9SAndroid Build Coastguard Worker                return ngrad, ngrad
5499*da0073e9SAndroid Build Coastguard Worker
5500*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(10, 3, requires_grad=True)
5501*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(10, 3, requires_grad=True)
5502*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
5503*da0073e9SAndroid Build Coastguard Worker        offsets = torch.tensor([0, 4])
5504*da0073e9SAndroid Build Coastguard Worker        import torch.nn.functional as F
5505*da0073e9SAndroid Build Coastguard Worker
5506*da0073e9SAndroid Build Coastguard Worker        # test case that should trigger no copy for one of a,b
5507*da0073e9SAndroid Build Coastguard Worker        emb_matrix = MyFunc.apply(a, b)
5508*da0073e9SAndroid Build Coastguard Worker        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
5509*da0073e9SAndroid Build Coastguard Worker        loss.backward(retain_graph=True)
5510*da0073e9SAndroid Build Coastguard Worker        p_g = MyFunc.static_grad_ptr
5511*da0073e9SAndroid Build Coastguard Worker        p_a = a.grad._values().data_ptr()
5512*da0073e9SAndroid Build Coastguard Worker        p_b = b.grad._values().data_ptr()
5513*da0073e9SAndroid Build Coastguard Worker        # check that a and b use different grad buffers
5514*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(p_a == p_b)
5515*da0073e9SAndroid Build Coastguard Worker        # check one of them is using the computed buffer
5516*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(p_a == p_g or p_b == p_g)
5517*da0073e9SAndroid Build Coastguard Worker
5518*da0073e9SAndroid Build Coastguard Worker        # Run backwards multiple times to ensure accumulation works.
5519*da0073e9SAndroid Build Coastguard Worker        for i in range(10):
5520*da0073e9SAndroid Build Coastguard Worker            loss.backward(retain_graph=True)
5521*da0073e9SAndroid Build Coastguard Worker
5522*da0073e9SAndroid Build Coastguard Worker        # non-contiguous indices and values, so we should trigger a copy.
5523*da0073e9SAndroid Build Coastguard Worker        a.grad = b.grad = None
5524*da0073e9SAndroid Build Coastguard Worker        emb_matrix = NonContGradFunc.apply(a, b)
5525*da0073e9SAndroid Build Coastguard Worker        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
5526*da0073e9SAndroid Build Coastguard Worker        loss.backward(retain_graph=True)
5527*da0073e9SAndroid Build Coastguard Worker        p_g = NonContGradFunc.static_grad_ptr
5528*da0073e9SAndroid Build Coastguard Worker        p_a = a.grad._values().data_ptr()
5529*da0073e9SAndroid Build Coastguard Worker        p_b = b.grad._values().data_ptr()
5530*da0073e9SAndroid Build Coastguard Worker        # check that a and b use different grad buffers
5531*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(p_a == p_b)
5532*da0073e9SAndroid Build Coastguard Worker        # Verify we cloned both grads.
5533*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(p_a == p_g)
5534*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(p_b == p_g)
5535*da0073e9SAndroid Build Coastguard Worker
5536*da0073e9SAndroid Build Coastguard Worker        # Run backwards multiple times to ensure accumulation works.
5537*da0073e9SAndroid Build Coastguard Worker        for i in range(10):
5538*da0073e9SAndroid Build Coastguard Worker            loss.backward(retain_graph=True)
5539*da0073e9SAndroid Build Coastguard Worker
5540*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_single_input(self):
5541*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5542*da0073e9SAndroid Build Coastguard Worker            def f(inp):
5543*da0073e9SAndroid Build Coastguard Worker                return inp.mul(5)
5544*da0073e9SAndroid Build Coastguard Worker
5545*da0073e9SAndroid Build Coastguard Worker            gradcheck(
5546*da0073e9SAndroid Build Coastguard Worker                f,
5547*da0073e9SAndroid Build Coastguard Worker                torch.rand(10, dtype=torch.float64, requires_grad=True),
5548*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5549*da0073e9SAndroid Build Coastguard Worker            )
5550*da0073e9SAndroid Build Coastguard Worker            gradgradcheck(
5551*da0073e9SAndroid Build Coastguard Worker                f,
5552*da0073e9SAndroid Build Coastguard Worker                torch.rand(10, dtype=torch.float64, requires_grad=True),
5553*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5554*da0073e9SAndroid Build Coastguard Worker            )
5555*da0073e9SAndroid Build Coastguard Worker
5556*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5557*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5558*da0073e9SAndroid Build Coastguard Worker
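    # Editor's sketch, not collected by unittest: gradcheck compares the analytical
    # Jacobian computed through backward() against a finite-difference estimate, so a
    # custom Function with an intentionally wrong backward is rejected. BadDouble is a
    # hypothetical example, not something the suite defines elsewhere.
    def _sketch_gradcheck_catches_wrong_backward(self):
        class BadDouble(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, gO):
                return gO * 3  # wrong on purpose: the true gradient is gO * 2

        x = torch.rand(5, dtype=torch.float64, requires_grad=True)
        self.assertFalse(
            gradcheck(
                BadDouble.apply,
                (x,),
                raise_exception=False,
                check_batched_grad=False,
                fast_mode=False,
            )
        )
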
5559*da0073e9SAndroid Build Coastguard Worker    @parametrize(
5560*da0073e9SAndroid Build Coastguard Worker        "layout",
5561*da0073e9SAndroid Build Coastguard Worker        (
5562*da0073e9SAndroid Build Coastguard Worker            torch.sparse_coo,
5563*da0073e9SAndroid Build Coastguard Worker            torch.sparse_csr,
5564*da0073e9SAndroid Build Coastguard Worker            torch.sparse_csc,
5565*da0073e9SAndroid Build Coastguard Worker            torch.sparse_bsr,
5566*da0073e9SAndroid Build Coastguard Worker            torch.sparse_bsc,
5567*da0073e9SAndroid Build Coastguard Worker        ),
5568*da0073e9SAndroid Build Coastguard Worker    )
5569*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_input(self, layout):
5570*da0073e9SAndroid Build Coastguard Worker        if layout in {torch.sparse_bsr, torch.sparse_bsc}:
5571*da0073e9SAndroid Build Coastguard Worker            blocksize = (2, 2)
5572*da0073e9SAndroid Build Coastguard Worker            size = (4, 8)
5573*da0073e9SAndroid Build Coastguard Worker        else:
5574*da0073e9SAndroid Build Coastguard Worker            blocksize = None
5575*da0073e9SAndroid Build Coastguard Worker            size = (2, 2)
5576*da0073e9SAndroid Build Coastguard Worker
5577*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode, masked):
5578*da0073e9SAndroid Build Coastguard Worker            def fn(sparse):
5579*da0073e9SAndroid Build Coastguard Worker                return torch.sum(sparse)
5580*da0073e9SAndroid Build Coastguard Worker
5581*da0073e9SAndroid Build Coastguard Worker            gradcheck(
5582*da0073e9SAndroid Build Coastguard Worker                fn,
5583*da0073e9SAndroid Build Coastguard Worker                torch.rand(size, dtype=torch.double)
5584*da0073e9SAndroid Build Coastguard Worker                .to_sparse(layout=layout, blocksize=blocksize)
5585*da0073e9SAndroid Build Coastguard Worker                .requires_grad_(),
5586*da0073e9SAndroid Build Coastguard Worker                masked=masked,
5587*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
5588*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5589*da0073e9SAndroid Build Coastguard Worker            )
5590*da0073e9SAndroid Build Coastguard Worker
5591*da0073e9SAndroid Build Coastguard Worker        for fast_mode, masked in product(*[(True, False)] * 2):
5592*da0073e9SAndroid Build Coastguard Worker            check(fast_mode=fast_mode, masked=masked)
5593*da0073e9SAndroid Build Coastguard Worker
5594*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_nondeterministic(self):
5595*da0073e9SAndroid Build Coastguard Worker        class NonDetFunc(Function):
5596*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5597*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, jitter=0.0):
5598*da0073e9SAndroid Build Coastguard Worker                ctx._jitter = jitter
5599*da0073e9SAndroid Build Coastguard Worker                return x
5600*da0073e9SAndroid Build Coastguard Worker
5601*da0073e9SAndroid Build Coastguard Worker            @staticmethod
5602*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_out):
5603*da0073e9SAndroid Build Coastguard Worker                return (
5604*da0073e9SAndroid Build Coastguard Worker                    NonDetFunc.apply(grad_out, ctx._jitter)
5605*da0073e9SAndroid Build Coastguard Worker                    * (1 + torch.rand_like(grad_out) * ctx._jitter),
5606*da0073e9SAndroid Build Coastguard Worker                    None,
5607*da0073e9SAndroid Build Coastguard Worker                )
5608*da0073e9SAndroid Build Coastguard Worker
5609*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5610*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
5611*da0073e9SAndroid Build Coastguard Worker            gradcheck(
5612*da0073e9SAndroid Build Coastguard Worker                lambda x: NonDetFunc.apply(x, 0.0),
5613*da0073e9SAndroid Build Coastguard Worker                inp,
5614*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
5615*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5616*da0073e9SAndroid Build Coastguard Worker            )
5617*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"):
5618*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5619*da0073e9SAndroid Build Coastguard Worker                    lambda x: NonDetFunc.apply(x, 1e-6),
5620*da0073e9SAndroid Build Coastguard Worker                    inp,
5621*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5622*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5623*da0073e9SAndroid Build Coastguard Worker                )
5624*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"):
5625*da0073e9SAndroid Build Coastguard Worker                gradgradcheck(
5626*da0073e9SAndroid Build Coastguard Worker                    lambda x: NonDetFunc.apply(x, 1e-12),
5627*da0073e9SAndroid Build Coastguard Worker                    inp,
5628*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5629*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5630*da0073e9SAndroid Build Coastguard Worker                )
5631*da0073e9SAndroid Build Coastguard Worker            gradcheck(
5632*da0073e9SAndroid Build Coastguard Worker                lambda x: NonDetFunc.apply(x, 0.0),
5633*da0073e9SAndroid Build Coastguard Worker                inp,
5634*da0073e9SAndroid Build Coastguard Worker                nondet_tol=1e-5,
5635*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
5636*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5637*da0073e9SAndroid Build Coastguard Worker            )
5638*da0073e9SAndroid Build Coastguard Worker            gradcheck(
5639*da0073e9SAndroid Build Coastguard Worker                lambda x: NonDetFunc.apply(x, 1e-6),
5640*da0073e9SAndroid Build Coastguard Worker                inp,
5641*da0073e9SAndroid Build Coastguard Worker                nondet_tol=1e-5,
5642*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
5643*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5644*da0073e9SAndroid Build Coastguard Worker            )
5645*da0073e9SAndroid Build Coastguard Worker            gradgradcheck(
5646*da0073e9SAndroid Build Coastguard Worker                lambda x: NonDetFunc.apply(x, 1e-12),
5647*da0073e9SAndroid Build Coastguard Worker                inp,
5648*da0073e9SAndroid Build Coastguard Worker                nondet_tol=1e-5,
5649*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=False,
5650*da0073e9SAndroid Build Coastguard Worker                fast_mode=fast_mode,
5651*da0073e9SAndroid Build Coastguard Worker            )
5652*da0073e9SAndroid Build Coastguard Worker
5653*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5654*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5655*da0073e9SAndroid Build Coastguard Worker
5656*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_validates_inputs(self):
5657*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5658*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, requires_grad=True).to_sparse()
5659*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
5660*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5661*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(),
5662*da0073e9SAndroid Build Coastguard Worker                    (x,),
5663*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5664*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5665*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5666*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5667*da0073e9SAndroid Build Coastguard Worker                )
5668*da0073e9SAndroid Build Coastguard Worker            )
5669*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5670*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5671*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(),
5672*da0073e9SAndroid Build Coastguard Worker                    (x,),
5673*da0073e9SAndroid Build Coastguard Worker                    masked=False,
5674*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5675*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5676*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5677*da0073e9SAndroid Build Coastguard Worker                )
5678*da0073e9SAndroid Build Coastguard Worker            )
5679*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
5680*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5681*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(masked_grad=False),
5682*da0073e9SAndroid Build Coastguard Worker                    (x,),
5683*da0073e9SAndroid Build Coastguard Worker                    masked=False,
5684*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5685*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5686*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5687*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5688*da0073e9SAndroid Build Coastguard Worker                )
5689*da0073e9SAndroid Build Coastguard Worker            )
5690*da0073e9SAndroid Build Coastguard Worker
5691*da0073e9SAndroid Build Coastguard Worker            # when none of the inputs require grad (always raises even if raise_exception=False)
5692*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, requires_grad=False)
5693*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5694*da0073e9SAndroid Build Coastguard Worker                ValueError, "at least one input tensor to require gradient"
5695*da0073e9SAndroid Build Coastguard Worker            ):
5696*da0073e9SAndroid Build Coastguard Worker                gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
5697*da0073e9SAndroid Build Coastguard Worker
5698*da0073e9SAndroid Build Coastguard Worker            # (warning) when inputs are not double precision
5699*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, dtype=torch.float32, requires_grad=True)
5700*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(
5701*da0073e9SAndroid Build Coastguard Worker                UserWarning, "Input #0 requires gradient and is not a double precision"
5702*da0073e9SAndroid Build Coastguard Worker            ):
5703*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(
5704*da0073e9SAndroid Build Coastguard Worker                    gradcheck(lambda x: x, (x,), atol=1e-1, fast_mode=fast_mode)
5705*da0073e9SAndroid Build Coastguard Worker                )
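            # NB (editor's note, illustrative only): the float32 warning above is avoided by running
            # gradcheck on double-precision inputs, e.g.
            #   gradcheck(lambda x: x * 2, (torch.ones(1, dtype=torch.float64, requires_grad=True),))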
5706*da0073e9SAndroid Build Coastguard Worker
5707*da0073e9SAndroid Build Coastguard Worker            # when the layout is not mkldnn (i.e. the tensor has strides) and the input has a
5708*da0073e9SAndroid Build Coastguard Worker            # dimension with stride 0 (always raises even if raise_exception=False)
5709*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, dtype=torch.float64, requires_grad=True)
5710*da0073e9SAndroid Build Coastguard Worker            x = x.expand((2, 2))
5711*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5712*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "The 0th input has a dimension with stride 0"
5713*da0073e9SAndroid Build Coastguard Worker            ):
5714*da0073e9SAndroid Build Coastguard Worker                gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode)
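            # NB (editor's note): x.expand((2, 2)) aliases the single underlying element, so the
            # expanded dimensions have stride 0; calling .contiguous() (or .clone()) on the input
            # before gradcheck would avoid this validation error.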
5715*da0073e9SAndroid Build Coastguard Worker
5716*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5717*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5718*da0073e9SAndroid Build Coastguard Worker
5719*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
5720*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
5721*da0073e9SAndroid Build Coastguard Worker    )
5722*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_validates_input_mkldnn(self):
5723*da0073e9SAndroid Build Coastguard Worker        # when inputs are mkldnn, forward-mode testing is not allowed
5724*da0073e9SAndroid Build Coastguard Worker        # Update tolerances below to make sure the gradients match even in single-precision floats
5725*da0073e9SAndroid Build Coastguard Worker        # Use the warning assert to hide the float32 warning
5726*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(1).to_mkldnn().requires_grad_()
5727*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
5728*da0073e9SAndroid Build Coastguard Worker            UserWarning, "Input #0 requires gradient and is not a double precision"
5729*da0073e9SAndroid Build Coastguard Worker        ):
5730*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5731*da0073e9SAndroid Build Coastguard Worker                ValueError, "MKLDNN inputs are not support for forward AD gradcheck."
5732*da0073e9SAndroid Build Coastguard Worker            ):
5733*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5734*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(),
5735*da0073e9SAndroid Build Coastguard Worker                    (x,),
5736*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5737*da0073e9SAndroid Build Coastguard Worker                    fast_mode=False,
5738*da0073e9SAndroid Build Coastguard Worker                    check_forward_ad=True,
5739*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5740*da0073e9SAndroid Build Coastguard Worker                    rtol=1e-1,
5741*da0073e9SAndroid Build Coastguard Worker                )
5742*da0073e9SAndroid Build Coastguard Worker
5743*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
5744*da0073e9SAndroid Build Coastguard Worker            UserWarning, "Input #0 requires gradient and is not a double precision"
5745*da0073e9SAndroid Build Coastguard Worker        ):
5746*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5747*da0073e9SAndroid Build Coastguard Worker                ValueError, "MKLDNN inputs are not support for forward AD gradcheck."
5748*da0073e9SAndroid Build Coastguard Worker            ):
5749*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5750*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(),
5751*da0073e9SAndroid Build Coastguard Worker                    (x,),
5752*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5753*da0073e9SAndroid Build Coastguard Worker                    fast_mode=True,
5754*da0073e9SAndroid Build Coastguard Worker                    check_forward_ad=True,
5755*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5756*da0073e9SAndroid Build Coastguard Worker                    rtol=1e-1,
5757*da0073e9SAndroid Build Coastguard Worker                )
5758*da0073e9SAndroid Build Coastguard Worker
5759*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
5760*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
5761*da0073e9SAndroid Build Coastguard Worker    )
5762*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_test_outputs(self):
5763*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5764*da0073e9SAndroid Build Coastguard Worker            # when outputs are sparse (always raises even if raise_exception=False)
5765*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, requires_grad=True).to_sparse()
5766*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5767*da0073e9SAndroid Build Coastguard Worker                ValueError, "Sparse output is not supported at gradcheck yet"
5768*da0073e9SAndroid Build Coastguard Worker            ):
5769*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5770*da0073e9SAndroid Build Coastguard Worker                    lambda x: x,
5771*da0073e9SAndroid Build Coastguard Worker                    (x,),
5772*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5773*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5774*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5775*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5776*da0073e9SAndroid Build Coastguard Worker                )
5777*da0073e9SAndroid Build Coastguard Worker
5778*da0073e9SAndroid Build Coastguard Worker            # when outputs are mkldnn (always raises even if raise_exception=False)
5779*da0073e9SAndroid Build Coastguard Worker            root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
5780*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5781*da0073e9SAndroid Build Coastguard Worker                ValueError, "MKLDNN output is not supported at gradcheck yet"
5782*da0073e9SAndroid Build Coastguard Worker            ):
5783*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5784*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_mkldnn(),
5785*da0073e9SAndroid Build Coastguard Worker                    (root,),
5786*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5787*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5788*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5789*da0073e9SAndroid Build Coastguard Worker                )
5790*da0073e9SAndroid Build Coastguard Worker
5791*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5792*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5793*da0073e9SAndroid Build Coastguard Worker
5794*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_check_no_differentiable_outputs(self):
5795*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5796*da0073e9SAndroid Build Coastguard Worker            # When none of the outputs are differentiable, but the numerical gradient is not zero
5797*da0073e9SAndroid Build Coastguard Worker            x = torch.ones((1,), requires_grad=True)
5798*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5799*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Numerical gradient for function expected to be zero"
5800*da0073e9SAndroid Build Coastguard Worker            ):
5801*da0073e9SAndroid Build Coastguard Worker                gradcheck(lambda x: torch.tensor([x]), x)
5802*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5803*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5804*da0073e9SAndroid Build Coastguard Worker                    lambda x: torch.tensor([x]),
5805*da0073e9SAndroid Build Coastguard Worker                    x,
5806*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5807*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5808*da0073e9SAndroid Build Coastguard Worker                )
5809*da0073e9SAndroid Build Coastguard Worker            )
5810*da0073e9SAndroid Build Coastguard Worker
5811*da0073e9SAndroid Build Coastguard Worker            # succeeds when there are no outputs at all
5812*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(lambda x: (), (x,), fast_mode=fast_mode))
5813*da0073e9SAndroid Build Coastguard Worker
5814*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5815*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5816*da0073e9SAndroid Build Coastguard Worker
5817*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_check_batched_grad(self):
5818*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5819*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(10, dtype=torch.double, requires_grad=True).to_sparse()
5820*da0073e9SAndroid Build Coastguard Worker            # runtime error while computing batched grad (prints a verbose error)
5821*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5822*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
5823*da0073e9SAndroid Build Coastguard Worker                "gradcheck or gradgradcheck failed while testing batched gradient",
5824*da0073e9SAndroid Build Coastguard Worker            ):
5825*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5826*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(),
5827*da0073e9SAndroid Build Coastguard Worker                    (x,),
5828*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5829*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=True,
5830*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5831*da0073e9SAndroid Build Coastguard Worker                )
5832*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5833*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5834*da0073e9SAndroid Build Coastguard Worker                    lambda x: x.to_dense(),
5835*da0073e9SAndroid Build Coastguard Worker                    (x,),
5836*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5837*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=True,
5838*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5839*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5840*da0073e9SAndroid Build Coastguard Worker                )
5841*da0073e9SAndroid Build Coastguard Worker            )
5842*da0073e9SAndroid Build Coastguard Worker
5843*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5844*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5845*da0073e9SAndroid Build Coastguard Worker
5846*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_backward_mul_by_grad_output(self):
5847*da0073e9SAndroid Build Coastguard Worker        # when grad_input is sparse and has incorrect sparse_dim/dense_dim
5848*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5849*da0073e9SAndroid Build Coastguard Worker            def fn(x):
5850*da0073e9SAndroid Build Coastguard Worker                def hook(grad):
5851*da0073e9SAndroid Build Coastguard Worker                    if grad is not None:
5852*da0073e9SAndroid Build Coastguard Worker                        return grad.to_dense().to_sparse(1)
5853*da0073e9SAndroid Build Coastguard Worker                    return grad
5854*da0073e9SAndroid Build Coastguard Worker
5855*da0073e9SAndroid Build Coastguard Worker                y = x.clone()
5856*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook)
5857*da0073e9SAndroid Build Coastguard Worker                return y.to_dense()
5858*da0073e9SAndroid Build Coastguard Worker
5859*da0073e9SAndroid Build Coastguard Worker            x = torch.ones((2, 2), dtype=torch.double, requires_grad=True).to_sparse()
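            # NB (editor's note): x has sparse_dim=2 / dense_dim=0, while the hook above rewrites
            # its gradient via to_sparse(1) into sparse_dim=1 / dense_dim=1, so gradcheck's
            # backward check flags the sparse_dim/dense_dim mismatch below.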
5860*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5861*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "grad is sparse tensor, but has incorrect sparse_dim"
5862*da0073e9SAndroid Build Coastguard Worker            ):
5863*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5864*da0073e9SAndroid Build Coastguard Worker                    fn,
5865*da0073e9SAndroid Build Coastguard Worker                    (x,),
5866*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5867*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5868*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5869*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5870*da0073e9SAndroid Build Coastguard Worker                )
5871*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5872*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5873*da0073e9SAndroid Build Coastguard Worker                    fn,
5874*da0073e9SAndroid Build Coastguard Worker                    (x,),
5875*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5876*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5877*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5878*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5879*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5880*da0073e9SAndroid Build Coastguard Worker                )
5881*da0073e9SAndroid Build Coastguard Worker            )
5882*da0073e9SAndroid Build Coastguard Worker
5883*da0073e9SAndroid Build Coastguard Worker            # when backward not multiplied by grad_output (non-sparse case)
5884*da0073e9SAndroid Build Coastguard Worker            def fn2(x):
5885*da0073e9SAndroid Build Coastguard Worker                y = x.clone()
5886*da0073e9SAndroid Build Coastguard Worker                y.register_hook(lambda x: x + 1e-2)
5887*da0073e9SAndroid Build Coastguard Worker                return y
5888*da0073e9SAndroid Build Coastguard Worker
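            # NB (editor's note): the hook adds a constant to every incoming gradient, so backward()
            # produces a nonzero grad even for a zero grad_output; gradcheck detects this by
            # backpropagating zeros, hence the "backward not multiplied by grad_output" error.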
5889*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, dtype=torch.double, requires_grad=True)
5890*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5891*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "backward not multiplied by grad_output"
5892*da0073e9SAndroid Build Coastguard Worker            ):
5893*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn2, (x,), atol=1e-1, fast_mode=fast_mode)
5894*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5895*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5896*da0073e9SAndroid Build Coastguard Worker                    fn2, (x,), atol=1e-1, raise_exception=False, fast_mode=fast_mode
5897*da0073e9SAndroid Build Coastguard Worker                )
5898*da0073e9SAndroid Build Coastguard Worker            )
5899*da0073e9SAndroid Build Coastguard Worker
5900*da0073e9SAndroid Build Coastguard Worker            # when backward not multiplied by grad_output (sparse case)
5901*da0073e9SAndroid Build Coastguard Worker            def fn3(x):
5902*da0073e9SAndroid Build Coastguard Worker                y = x.clone().to_dense()
5903*da0073e9SAndroid Build Coastguard Worker                y.register_hook(lambda x: x + 1e-2)
5904*da0073e9SAndroid Build Coastguard Worker                return y
5905*da0073e9SAndroid Build Coastguard Worker
5906*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, dtype=torch.double, requires_grad=True).to_sparse()
5907*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5908*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "backward not multiplied by grad_output"
5909*da0073e9SAndroid Build Coastguard Worker            ):
5910*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5911*da0073e9SAndroid Build Coastguard Worker                    fn3,
5912*da0073e9SAndroid Build Coastguard Worker                    (x,),
5913*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5914*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5915*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5916*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5917*da0073e9SAndroid Build Coastguard Worker                )
5918*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5919*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5920*da0073e9SAndroid Build Coastguard Worker                    fn3,
5921*da0073e9SAndroid Build Coastguard Worker                    (x,),
5922*da0073e9SAndroid Build Coastguard Worker                    atol=1e-1,
5923*da0073e9SAndroid Build Coastguard Worker                    masked=True,
5924*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5925*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5926*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5927*da0073e9SAndroid Build Coastguard Worker                )
5928*da0073e9SAndroid Build Coastguard Worker            )
5929*da0073e9SAndroid Build Coastguard Worker
5930*da0073e9SAndroid Build Coastguard Worker            # when the layout of grad_input is not the same as that of the input
5931*da0073e9SAndroid Build Coastguard Worker            class Test(Function):
5932*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5933*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x):
5934*da0073e9SAndroid Build Coastguard Worker                    return x
5935*da0073e9SAndroid Build Coastguard Worker
5936*da0073e9SAndroid Build Coastguard Worker                @staticmethod
5937*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, x):
5938*da0073e9SAndroid Build Coastguard Worker                    return x.to_sparse()
5939*da0073e9SAndroid Build Coastguard Worker
5940*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, dtype=torch.double, requires_grad=True)
5941*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "grad is incorrect layout"):
5942*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5943*da0073e9SAndroid Build Coastguard Worker                    Test.apply, (x,), check_batched_grad=False, fast_mode=fast_mode
5944*da0073e9SAndroid Build Coastguard Worker                )
5945*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
5946*da0073e9SAndroid Build Coastguard Worker                gradcheck(
5947*da0073e9SAndroid Build Coastguard Worker                    Test.apply,
5948*da0073e9SAndroid Build Coastguard Worker                    (x,),
5949*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
5950*da0073e9SAndroid Build Coastguard Worker                    raise_exception=False,
5951*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
5952*da0073e9SAndroid Build Coastguard Worker                )
5953*da0073e9SAndroid Build Coastguard Worker            )
5954*da0073e9SAndroid Build Coastguard Worker
5955*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5956*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5957*da0073e9SAndroid Build Coastguard Worker
5958*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_undefined_grad(self):
5959*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5960*da0073e9SAndroid Build Coastguard Worker            # when a runtime error is encountered while running backward
5961*da0073e9SAndroid Build Coastguard Worker            def fn(x):
5962*da0073e9SAndroid Build Coastguard Worker                def hook(x):
5963*da0073e9SAndroid Build Coastguard Worker                    if x is None:
5964*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError("x is undefined")
5965*da0073e9SAndroid Build Coastguard Worker
5966*da0073e9SAndroid Build Coastguard Worker                y = x.clone()
5967*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook)
5968*da0073e9SAndroid Build Coastguard Worker                return y
5969*da0073e9SAndroid Build Coastguard Worker
5970*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(1, dtype=torch.double, requires_grad=True)
5971*da0073e9SAndroid Build Coastguard Worker            with self.assertWarnsRegex(
5972*da0073e9SAndroid Build Coastguard Worker                UserWarning,
5973*da0073e9SAndroid Build Coastguard Worker                "Backwards compatibility: New undefined gradient support checking feature",
5974*da0073e9SAndroid Build Coastguard Worker            ):
5975*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
5976*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
5977*da0073e9SAndroid Build Coastguard Worker                    "Expected backward function to handle undefined output grads",
5978*da0073e9SAndroid Build Coastguard Worker                ):
5979*da0073e9SAndroid Build Coastguard Worker                    gradcheck(fn, (x,), fast_mode=fast_mode)
5980*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(
5981*da0073e9SAndroid Build Coastguard Worker                    gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
5982*da0073e9SAndroid Build Coastguard Worker                )
5983*da0073e9SAndroid Build Coastguard Worker
5984*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
5985*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
5986*da0073e9SAndroid Build Coastguard Worker
5987*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_jacobian_mismatch(self):
5988*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
5989*da0073e9SAndroid Build Coastguard Worker            def fn(x):  # R -> R, C -> C
5990*da0073e9SAndroid Build Coastguard Worker                y = x.clone()
5991*da0073e9SAndroid Build Coastguard Worker                y.register_hook(lambda x: x + 1e-2)
5992*da0073e9SAndroid Build Coastguard Worker                return y
5993*da0073e9SAndroid Build Coastguard Worker
5994*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 2, requires_grad=True)
5995*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
5996*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
5997*da0073e9SAndroid Build Coastguard Worker            ):
5998*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x,), fast_mode=fast_mode)
5999*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6000*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
6001*da0073e9SAndroid Build Coastguard Worker            )
6002*da0073e9SAndroid Build Coastguard Worker
6003*da0073e9SAndroid Build Coastguard Worker            x_c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
6004*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6005*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
6006*da0073e9SAndroid Build Coastguard Worker                "While considering the imaginary part of complex outputs only",
6007*da0073e9SAndroid Build Coastguard Worker            ):
6008*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x_c,), fast_mode=False)
6009*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6010*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x_c,), raise_exception=False, fast_mode=False)
6011*da0073e9SAndroid Build Coastguard Worker            )
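            # NB (editor's note): for complex outputs gradcheck checks the real and imaginary parts
            # separately (grad_output = 1 and grad_output = 1j), which is why the failure above is
            # attributed to "the imaginary part of complex outputs only".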
6012*da0073e9SAndroid Build Coastguard Worker
6013*da0073e9SAndroid Build Coastguard Worker            def fn2(x):  # R -> C
6014*da0073e9SAndroid Build Coastguard Worker                y = torch.complex(x, x)
6015*da0073e9SAndroid Build Coastguard Worker                y.register_hook(lambda x: x + 1e-2)
6016*da0073e9SAndroid Build Coastguard Worker                return y
6017*da0073e9SAndroid Build Coastguard Worker
6018*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 2, requires_grad=True)
6019*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6020*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
6021*da0073e9SAndroid Build Coastguard Worker                "While considering the imaginary part of complex outputs only",
6022*da0073e9SAndroid Build Coastguard Worker            ):
6023*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn2, (x,), fast_mode=False)
6024*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6025*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn2, (x,), raise_exception=False, fast_mode=False)
6026*da0073e9SAndroid Build Coastguard Worker            )
6027*da0073e9SAndroid Build Coastguard Worker
6028*da0073e9SAndroid Build Coastguard Worker            def fn3(x):  # C -> R
6029*da0073e9SAndroid Build Coastguard Worker                y = torch.real(x)
6030*da0073e9SAndroid Build Coastguard Worker                y.register_hook(lambda x: x + 1e-2)
6031*da0073e9SAndroid Build Coastguard Worker                return y
6032*da0073e9SAndroid Build Coastguard Worker
6033*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6034*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
6035*da0073e9SAndroid Build Coastguard Worker            ):
6036*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn3, (x_c,), fast_mode=False)
6037*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6038*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn3, (x_c,), raise_exception=False, fast_mode=False)
6039*da0073e9SAndroid Build Coastguard Worker            )
6040*da0073e9SAndroid Build Coastguard Worker
6041*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
6042*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
6043*da0073e9SAndroid Build Coastguard Worker
6044*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_dense_and_sparse_inputs(self):
6045*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
6046*da0073e9SAndroid Build Coastguard Worker            def fn(x, y):
6047*da0073e9SAndroid Build Coastguard Worker                return x * y.coalesce().to_dense()
6048*da0073e9SAndroid Build Coastguard Worker
6049*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
6050*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(2, 2, dtype=torch.double).to_sparse().requires_grad_(True)
6051*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
6052*da0073e9SAndroid Build Coastguard Worker                gradcheck(
6053*da0073e9SAndroid Build Coastguard Worker                    fn,
6054*da0073e9SAndroid Build Coastguard Worker                    (a, b),
6055*da0073e9SAndroid Build Coastguard Worker                    masked=True,
6056*da0073e9SAndroid Build Coastguard Worker                    check_batched_grad=False,
6057*da0073e9SAndroid Build Coastguard Worker                    fast_mode=fast_mode,
6058*da0073e9SAndroid Build Coastguard Worker                )
6059*da0073e9SAndroid Build Coastguard Worker            )
6060*da0073e9SAndroid Build Coastguard Worker
6061*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
6062*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
6063*da0073e9SAndroid Build Coastguard Worker
6064*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6065*da0073e9SAndroid Build Coastguard Worker        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
6066*da0073e9SAndroid Build Coastguard Worker    )
6067*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_multiple_mkldnn_inputs(self):
6068*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
6069*da0073e9SAndroid Build Coastguard Worker            def fn(x, y):
6070*da0073e9SAndroid Build Coastguard Worker                return x + y.to_dense()
6071*da0073e9SAndroid Build Coastguard Worker
6072*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(10, requires_grad=True)
6073*da0073e9SAndroid Build Coastguard Worker            b = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
6074*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
6075*da0073e9SAndroid Build Coastguard Worker                gradcheck(
6076*da0073e9SAndroid Build Coastguard Worker                    fn, (a, b), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode
6077*da0073e9SAndroid Build Coastguard Worker                )
6078*da0073e9SAndroid Build Coastguard Worker            )
6079*da0073e9SAndroid Build Coastguard Worker
6080*da0073e9SAndroid Build Coastguard Worker            def fn2(x, y):
6081*da0073e9SAndroid Build Coastguard Worker                return x.to_dense() + y.to_dense()
6082*da0073e9SAndroid Build Coastguard Worker
6083*da0073e9SAndroid Build Coastguard Worker            c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
6084*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(
6085*da0073e9SAndroid Build Coastguard Worker                gradcheck(
6086*da0073e9SAndroid Build Coastguard Worker                    fn2, (b, c), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode
6087*da0073e9SAndroid Build Coastguard Worker                )
6088*da0073e9SAndroid Build Coastguard Worker            )
6089*da0073e9SAndroid Build Coastguard Worker
6090*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
6091*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
6092*da0073e9SAndroid Build Coastguard Worker
6093*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_output_shape_or_dtype_depend_on_values(self):
6094*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
6095*da0073e9SAndroid Build Coastguard Worker            def fn(x):
6096*da0073e9SAndroid Build Coastguard Worker                if torch.all(x >= 1):
6097*da0073e9SAndroid Build Coastguard Worker                    return torch.cat([x, x])
6098*da0073e9SAndroid Build Coastguard Worker                else:
6099*da0073e9SAndroid Build Coastguard Worker                    return x
6100*da0073e9SAndroid Build Coastguard Worker
6101*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1, dtype=torch.double, requires_grad=True)
6102*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6103*da0073e9SAndroid Build Coastguard Worker                AssertionError,
6104*da0073e9SAndroid Build Coastguard Worker                "return outputs with the same shape when inputs are perturbed",
6105*da0073e9SAndroid Build Coastguard Worker            ):
6106*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(gradcheck(fn, (a,), fast_mode=fast_mode))
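            # NB (editor's note): a == 1 exactly, so the finite-difference perturbation a - eps flips
            # the torch.all(x >= 1) branch and changes the output shape; fn2 below triggers the
            # analogous dtype check.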
6107*da0073e9SAndroid Build Coastguard Worker
6108*da0073e9SAndroid Build Coastguard Worker            def fn2(x):
6109*da0073e9SAndroid Build Coastguard Worker                if torch.all(x >= 1):
6110*da0073e9SAndroid Build Coastguard Worker                    return x.to(torch.float32)
6111*da0073e9SAndroid Build Coastguard Worker                else:
6112*da0073e9SAndroid Build Coastguard Worker                    return x
6113*da0073e9SAndroid Build Coastguard Worker
6114*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6115*da0073e9SAndroid Build Coastguard Worker                AssertionError,
6116*da0073e9SAndroid Build Coastguard Worker                "return outputs with the same dtype when inputs are perturbed",
6117*da0073e9SAndroid Build Coastguard Worker            ):
6118*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(gradcheck(fn2, (a,), fast_mode=fast_mode))
6119*da0073e9SAndroid Build Coastguard Worker
6120*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
6121*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
6122*da0073e9SAndroid Build Coastguard Worker
6123*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_complex_non_complex_outputs(self):
6124*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6125*da0073e9SAndroid Build Coastguard Worker            z = torch.complex(x, y)
6126*da0073e9SAndroid Build Coastguard Worker            return z, x + 1
6127*da0073e9SAndroid Build Coastguard Worker
6128*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
6129*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
6130*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(fn, (a, b)))
6131*da0073e9SAndroid Build Coastguard Worker
6132*da0073e9SAndroid Build Coastguard Worker        def fn2(z):
6133*da0073e9SAndroid Build Coastguard Worker            return z, torch.real(z)
6134*da0073e9SAndroid Build Coastguard Worker
6135*da0073e9SAndroid Build Coastguard Worker        c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
6136*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(gradcheck(fn2, (c,)))
6137*da0073e9SAndroid Build Coastguard Worker
6138*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_get_numerical_jacobian(self):
6139*da0073e9SAndroid Build Coastguard Worker        # get_numerical_jacobian is deprecated and no longer used internally by gradcheck
6140*da0073e9SAndroid Build Coastguard Worker        from torch.autograd.gradcheck import get_numerical_jacobian
6141*da0073e9SAndroid Build Coastguard Worker
6142*da0073e9SAndroid Build Coastguard Worker        def fn(inputs):
6143*da0073e9SAndroid Build Coastguard Worker            # get_numerical_jacobian requires fn to take its inputs as a single tuple
6144*da0073e9SAndroid Build Coastguard Worker            # and returns the Jacobian of the first output wrt each input (or wrt `target`)
6145*da0073e9SAndroid Build Coastguard Worker            x = inputs[0]
6146*da0073e9SAndroid Build Coastguard Worker            y = inputs[1]
6147*da0073e9SAndroid Build Coastguard Worker            return 2 * x + y, x + 2 * y
6148*da0073e9SAndroid Build Coastguard Worker
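        # NB (editor's note): the first output of fn is 2 * x + y, so its Jacobian wrt the flattened
        # 4-element input a is 2 * I_4 and wrt b is I_4; the assertions below check exactly this.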
6149*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6150*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6151*da0073e9SAndroid Build Coastguard Worker
6152*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
6153*da0073e9SAndroid Build Coastguard Worker            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
6154*da0073e9SAndroid Build Coastguard Worker        ):
6155*da0073e9SAndroid Build Coastguard Worker            jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
6156*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
6157*da0073e9SAndroid Build Coastguard Worker
6158*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
6159*da0073e9SAndroid Build Coastguard Worker            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
6160*da0073e9SAndroid Build Coastguard Worker        ):
6161*da0073e9SAndroid Build Coastguard Worker            jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
6162*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
6163*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jacobian[1], 1 * torch.eye(4, dtype=torch.double))
6164*da0073e9SAndroid Build Coastguard Worker
6165*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
6166*da0073e9SAndroid Build Coastguard Worker            jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=2.0)
6167*da0073e9SAndroid Build Coastguard Worker
6168*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_get_analytical_jacobian(self):
6169*da0073e9SAndroid Build Coastguard Worker        from torch.autograd.gradcheck import get_analytical_jacobian
6170*da0073e9SAndroid Build Coastguard Worker
6171*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6172*da0073e9SAndroid Build Coastguard Worker            return 2 * x + y, x + 2 * y
6173*da0073e9SAndroid Build Coastguard Worker
6174*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6175*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
6176*da0073e9SAndroid Build Coastguard Worker
6177*da0073e9SAndroid Build Coastguard Worker        outputs = fn(a, b)
6178*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
6179*da0073e9SAndroid Build Coastguard Worker            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
6180*da0073e9SAndroid Build Coastguard Worker        ):
6181*da0073e9SAndroid Build Coastguard Worker            (
6182*da0073e9SAndroid Build Coastguard Worker                jacobians,
6183*da0073e9SAndroid Build Coastguard Worker                reentrant,
6184*da0073e9SAndroid Build Coastguard Worker                correct_grad_sizes,
6185*da0073e9SAndroid Build Coastguard Worker                correct_grad_types,
6186*da0073e9SAndroid Build Coastguard Worker            ) = get_analytical_jacobian((a, b), outputs[0])
6187*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jacobians[0], 2 * torch.eye(4, dtype=torch.double))
6188*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jacobians[1], 1 * torch.eye(4, dtype=torch.double))
6189*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(reentrant)
6190*da0073e9SAndroid Build Coastguard Worker
6191*da0073e9SAndroid Build Coastguard Worker        class NonDetFunc(Function):
6192*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6193*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, jitter=0.0):
6194*da0073e9SAndroid Build Coastguard Worker                ctx._jitter = jitter
6195*da0073e9SAndroid Build Coastguard Worker                return x
6196*da0073e9SAndroid Build Coastguard Worker
6197*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6198*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_out):
6199*da0073e9SAndroid Build Coastguard Worker                return (
6200*da0073e9SAndroid Build Coastguard Worker                    NonDetFunc.apply(grad_out, ctx._jitter)
6201*da0073e9SAndroid Build Coastguard Worker                    * (1 + torch.rand_like(grad_out) * ctx._jitter),
6202*da0073e9SAndroid Build Coastguard Worker                    None,
6203*da0073e9SAndroid Build Coastguard Worker                )
6204*da0073e9SAndroid Build Coastguard Worker
6205*da0073e9SAndroid Build Coastguard Worker        outputs = NonDetFunc.apply(a, 1e-6)
6206*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(
6207*da0073e9SAndroid Build Coastguard Worker            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
6208*da0073e9SAndroid Build Coastguard Worker        ):
6209*da0073e9SAndroid Build Coastguard Worker            (
6210*da0073e9SAndroid Build Coastguard Worker                jacobians,
6211*da0073e9SAndroid Build Coastguard Worker                reentrant,
6212*da0073e9SAndroid Build Coastguard Worker                correct_grad_sizes,
6213*da0073e9SAndroid Build Coastguard Worker                correct_grad_types,
6214*da0073e9SAndroid Build Coastguard Worker            ) = get_analytical_jacobian((a,), outputs)
6215*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(reentrant)
6216*da0073e9SAndroid Build Coastguard Worker
6217*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
6218*da0073e9SAndroid Build Coastguard Worker            jacobians, _, _, _ = get_analytical_jacobian((a,), outputs, grad_out=2.0)
6219*da0073e9SAndroid Build Coastguard Worker
6220*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_custom_error(self):
6221*da0073e9SAndroid Build Coastguard Worker        from torch.autograd.gradcheck import GradcheckError
6222*da0073e9SAndroid Build Coastguard Worker
6223*da0073e9SAndroid Build Coastguard Worker        def check(fast_mode):
6224*da0073e9SAndroid Build Coastguard Worker            def fn(x):
6225*da0073e9SAndroid Build Coastguard Worker                y = x.clone()
6226*da0073e9SAndroid Build Coastguard Worker                y.register_hook(lambda x: x + 1e-2)
6227*da0073e9SAndroid Build Coastguard Worker                return y
6228*da0073e9SAndroid Build Coastguard Worker
6229*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(2, 2, requires_grad=True)
6230*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6231*da0073e9SAndroid Build Coastguard Worker                GradcheckError, "Jacobian mismatch for output 0 with respect to input 0"
6232*da0073e9SAndroid Build Coastguard Worker            ):
6233*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x,), fast_mode=fast_mode)
6234*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6235*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
6236*da0073e9SAndroid Build Coastguard Worker            ):
6237*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x,), fast_mode=fast_mode)
6238*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(
6239*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
6240*da0073e9SAndroid Build Coastguard Worker            )
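            # NB (editor's note): GradcheckError is a subclass of RuntimeError, which is why the
            # same failure is caught under either exception type above.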
6241*da0073e9SAndroid Build Coastguard Worker
6242*da0073e9SAndroid Build Coastguard Worker            def fn2(x):
6243*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("Not a GradcheckError!")
6244*da0073e9SAndroid Build Coastguard Worker
6245*da0073e9SAndroid Build Coastguard Worker            # Checks that when raise_exception=False, non-GradcheckErrors are not caught by gradcheck
6246*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "Not a GradcheckError!"):
6247*da0073e9SAndroid Build Coastguard Worker                gradcheck(fn2, (x,), fast_mode=fast_mode, raise_exception=False)
6248*da0073e9SAndroid Build Coastguard Worker
6249*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=True)
6250*da0073e9SAndroid Build Coastguard Worker        check(fast_mode=False)
6251*da0073e9SAndroid Build Coastguard Worker
6252*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_forward_ad(self):
6253*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
6254*da0073e9SAndroid Build Coastguard Worker            return x + y, y
6255*da0073e9SAndroid Build Coastguard Worker
6256*da0073e9SAndroid Build Coastguard Worker        def bad_fn(x, y):
6257*da0073e9SAndroid Build Coastguard Worker            # Hacky way to check if we're currently inside a forward ad level
6258*da0073e9SAndroid Build Coastguard Worker            is_running_forward_ad = fwAD._current_level >= 0
6259*da0073e9SAndroid Build Coastguard Worker
6260*da0073e9SAndroid Build Coastguard Worker            if is_running_forward_ad:
6261*da0073e9SAndroid Build Coastguard Worker                y_p, y_d = fwAD.unpack_dual(y)
6262*da0073e9SAndroid Build Coastguard Worker                y = fwAD.make_dual(y_p, y_d * 1.1)
6263*da0073e9SAndroid Build Coastguard Worker
6264*da0073e9SAndroid Build Coastguard Worker            return x + y, y
6265*da0073e9SAndroid Build Coastguard Worker
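        # NB (editor's note): bad_fn scales y's tangent by 1.1 only while a forward-AD level is
        # active, so the Jacobian computed with forward mode disagrees with the reference for
        # input 1, producing err_msg below.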
6266*da0073e9SAndroid Build Coastguard Worker        err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1"
6267*da0073e9SAndroid Build Coastguard Worker
6268*da0073e9SAndroid Build Coastguard Worker        for fast_mode in [True, False]:
6269*da0073e9SAndroid Build Coastguard Worker            # Test for all inputs and outputs being real
6270*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(2, dtype=torch.double, requires_grad=True)
6271*da0073e9SAndroid Build Coastguard Worker            y = torch.rand(2, dtype=torch.double, requires_grad=True)
6272*da0073e9SAndroid Build Coastguard Worker
6273*da0073e9SAndroid Build Coastguard Worker            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6274*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err_msg):
6275*da0073e9SAndroid Build Coastguard Worker                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6276*da0073e9SAndroid Build Coastguard Worker
6277*da0073e9SAndroid Build Coastguard Worker            def basic_mul(x):
6278*da0073e9SAndroid Build Coastguard Worker                return torch.view_as_real(torch.resolve_conj(x * 1j))
6279*da0073e9SAndroid Build Coastguard Worker
6280*da0073e9SAndroid Build Coastguard Worker            gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)
6281*da0073e9SAndroid Build Coastguard Worker
6282*da0073e9SAndroid Build Coastguard Worker            # Test for one input and one output being complex
6283*da0073e9SAndroid Build Coastguard Worker            x = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
6284*da0073e9SAndroid Build Coastguard Worker
6285*da0073e9SAndroid Build Coastguard Worker            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6286*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err_msg):
6287*da0073e9SAndroid Build Coastguard Worker                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6288*da0073e9SAndroid Build Coastguard Worker
6289*da0073e9SAndroid Build Coastguard Worker            # Test for all inputs and outputs being complex
6290*da0073e9SAndroid Build Coastguard Worker            y = torch.rand(2, dtype=torch.cdouble, requires_grad=True)
6291*da0073e9SAndroid Build Coastguard Worker
6292*da0073e9SAndroid Build Coastguard Worker            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6293*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err_msg):
6294*da0073e9SAndroid Build Coastguard Worker                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
6295*da0073e9SAndroid Build Coastguard Worker
6296*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
6297*da0073e9SAndroid Build Coastguard Worker        # Currently requires_grad is used as an easy way for gradcheck to know
6298*da0073e9SAndroid Build Coastguard Worker        # which inputs of the function are meant to be differentiable
6299*da0073e9SAndroid Build Coastguard Worker        # This test checks that the inputs passed to the function do not have
6300*da0073e9SAndroid Build Coastguard Worker        # requires_grad=True, even though they may have requires_grad=True when passed
6301*da0073e9SAndroid Build Coastguard Worker        # to gradcheck
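        # NB (editor's note, assumption): this presumably holds because gradcheck builds the
        # forward-AD dual inputs from detached copies of the originals; the assertion inside
        # UserFn.forward below is what actually verifies the behavior.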
6302*da0073e9SAndroid Build Coastguard Worker        class UserFn(Function):
6303*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6304*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
6305*da0073e9SAndroid Build Coastguard Worker                if fwAD._current_level >= 0:
6306*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(x.requires_grad)
6307*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(y.requires_grad)
6308*da0073e9SAndroid Build Coastguard Worker                return x.clone(), y.clone()
6309*da0073e9SAndroid Build Coastguard Worker
6310*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6311*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_t, y_t):
6312*da0073e9SAndroid Build Coastguard Worker                return x_t, y_t
6313*da0073e9SAndroid Build Coastguard Worker
6314*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, dtype=torch.double, requires_grad=True)
6315*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(2, dtype=torch.double, requires_grad=True)
6316*da0073e9SAndroid Build Coastguard Worker
6317*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6318*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6319*da0073e9SAndroid Build Coastguard Worker            (x, y),
6320*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6321*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=False,
6322*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6323*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6324*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=False,
6325*da0073e9SAndroid Build Coastguard Worker        )
6326*da0073e9SAndroid Build Coastguard Worker
6327*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6328*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6329*da0073e9SAndroid Build Coastguard Worker            (x, y),
6330*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6331*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=True,
6332*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6333*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6334*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=False,
6335*da0073e9SAndroid Build Coastguard Worker        )
6336*da0073e9SAndroid Build Coastguard Worker
6337*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6338*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6339*da0073e9SAndroid Build Coastguard Worker            (x, y),
6340*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6341*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=True,
6342*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6343*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6344*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=True,
6345*da0073e9SAndroid Build Coastguard Worker        )
6346*da0073e9SAndroid Build Coastguard Worker
6347*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, dtype=torch.double, requires_grad=True)
6348*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(2, dtype=torch.double, requires_grad=False)
6349*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6350*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6351*da0073e9SAndroid Build Coastguard Worker            (x, y),
6352*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6353*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=True,
6354*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6355*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6356*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=True,
6357*da0073e9SAndroid Build Coastguard Worker        )
6358*da0073e9SAndroid Build Coastguard Worker
6359*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_forward_ad_respects_requires_grad(self):
6360*da0073e9SAndroid Build Coastguard Worker        # Currently requires_grad is used as an easy way for gradcheck to know
6361*da0073e9SAndroid Build Coastguard Worker        # which inputs of the function are meant to be differentiable
6362*da0073e9SAndroid Build Coastguard Worker        jvp_count = [0]
6363*da0073e9SAndroid Build Coastguard Worker
6364*da0073e9SAndroid Build Coastguard Worker        class UserFn(Function):
6365*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6366*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
6367*da0073e9SAndroid Build Coastguard Worker                return x.clone(), y.clone()
6368*da0073e9SAndroid Build Coastguard Worker
6369*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6370*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_t, y_t):
6371*da0073e9SAndroid Build Coastguard Worker                jvp_count[0] += 1
6372*da0073e9SAndroid Build Coastguard Worker                return x_t, y_t
6373*da0073e9SAndroid Build Coastguard Worker
6374*da0073e9SAndroid Build Coastguard Worker        # NB: In slow gradcheck the jvp is evaluated once per input element (the loop runs numel
6375*da0073e9SAndroid Build Coastguard Worker        #     times), so use numel = 1 to ensure that fast and slow modes have the same counts
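        # Worked count (editor's note): with numel = 1 and two differentiable inputs, both modes
        # call jvp once per input, which is the jvp_count == 2 asserted below.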
6376*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, dtype=torch.double, requires_grad=True)
6377*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(1, dtype=torch.double, requires_grad=True)
6378*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6379*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6380*da0073e9SAndroid Build Coastguard Worker            (x, y),
6381*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6382*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=False,
6383*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6384*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6385*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=False,
6386*da0073e9SAndroid Build Coastguard Worker        )
6387*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jvp_count[0], 2)  # (2) once per input
6388*da0073e9SAndroid Build Coastguard Worker        jvp_count = [0]
6389*da0073e9SAndroid Build Coastguard Worker
6390*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6391*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6392*da0073e9SAndroid Build Coastguard Worker            (x, y),
6393*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6394*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=True,
6395*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6396*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6397*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=False,
6398*da0073e9SAndroid Build Coastguard Worker        )
6399*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6400*da0073e9SAndroid Build Coastguard Worker            jvp_count[0], 6
6401*da0073e9SAndroid Build Coastguard Worker        )  # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2)
6402*da0073e9SAndroid Build Coastguard Worker        jvp_count = [0]
6403*da0073e9SAndroid Build Coastguard Worker
6404*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6405*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6406*da0073e9SAndroid Build Coastguard Worker            (x, y),
6407*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6408*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=True,
6409*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6410*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6411*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=True,
6412*da0073e9SAndroid Build Coastguard Worker        )
6413*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6414*da0073e9SAndroid Build Coastguard Worker            jvp_count[0], 12
6415*da0073e9SAndroid Build Coastguard Worker        )  # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
6416*da0073e9SAndroid Build Coastguard Worker        jvp_count = [0]
6417*da0073e9SAndroid Build Coastguard Worker
6418*da0073e9SAndroid Build Coastguard Worker        # Repeat the previous test, except that we mark one input with requires_grad=False
6419*da0073e9SAndroid Build Coastguard Worker        # NB: _test_undefined_forward_mode is only (+1) when the function has a single differentiable input, not (+2)!
6420*da0073e9SAndroid Build Coastguard Worker        #     The other counts are simply halved.
6421*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, dtype=torch.double, requires_grad=True)
6422*da0073e9SAndroid Build Coastguard Worker        y = torch.rand(1, dtype=torch.double, requires_grad=False)
6423*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6424*da0073e9SAndroid Build Coastguard Worker            UserFn.apply,
6425*da0073e9SAndroid Build Coastguard Worker            (x, y),
6426*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6427*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=True,
6428*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6429*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6430*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=True,
6431*da0073e9SAndroid Build Coastguard Worker        )
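        # Breakdown of the expected 5 jvp calls with a single differentiable input:
        #   +1 basic forward AD check, +1 undefined-forward-mode check (single-input case),
        #   +3 batched forward grad check (vmap +1, loop over the batch of 2 +2)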
6432*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jvp_count[0], 5)  # 1 + 1 + 3
6433*da0073e9SAndroid Build Coastguard Worker
6434*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_check_forward_or_backward_only(self):
6435*da0073e9SAndroid Build Coastguard Worker        """Depending on the settings for check_forward_ad and check_backward_ad, the
6436*da0073e9SAndroid Build Coastguard Worker        correct codepaths should be reached (or not reached).
6437*da0073e9SAndroid Build Coastguard Worker        """
6438*da0073e9SAndroid Build Coastguard Worker        fwd_fail_err_msg = "FAIL FWD"
6439*da0073e9SAndroid Build Coastguard Worker        bwd_fail_err_msg = "FAIL BWD"
6440*da0073e9SAndroid Build Coastguard Worker
6441*da0073e9SAndroid Build Coastguard Worker        class UserFn(Function):
6442*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6443*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo, fwd_bad, bwd_bad):
6444*da0073e9SAndroid Build Coastguard Worker                ctx.fwd_bad = fwd_bad
6445*da0073e9SAndroid Build Coastguard Worker                ctx.bwd_bad = bwd_bad
6446*da0073e9SAndroid Build Coastguard Worker                return foo * 2
6447*da0073e9SAndroid Build Coastguard Worker
6448*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6449*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, gO):
6450*da0073e9SAndroid Build Coastguard Worker                if ctx.bwd_bad:
6451*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(bwd_fail_err_msg)
6452*da0073e9SAndroid Build Coastguard Worker                else:
6453*da0073e9SAndroid Build Coastguard Worker                    return 2 * gO, None, None
6454*da0073e9SAndroid Build Coastguard Worker
6455*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6456*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, gI, _1, _2):
6457*da0073e9SAndroid Build Coastguard Worker                if ctx.fwd_bad:
6458*da0073e9SAndroid Build Coastguard Worker                    raise RuntimeError(fwd_fail_err_msg)
6459*da0073e9SAndroid Build Coastguard Worker                else:
6460*da0073e9SAndroid Build Coastguard Worker                    return 2 * gI
6461*da0073e9SAndroid Build Coastguard Worker
6462*da0073e9SAndroid Build Coastguard Worker        for fast_mode in (True, False):
6463*da0073e9SAndroid Build Coastguard Worker            for check_forward_ad in (True, False):
6464*da0073e9SAndroid Build Coastguard Worker                for check_backward_ad in (True, False):
6465*da0073e9SAndroid Build Coastguard Worker                    for fwd_bad in (True, False):
6466*da0073e9SAndroid Build Coastguard Worker                        for bwd_bad in (True, False):
6467*da0073e9SAndroid Build Coastguard Worker                            fwd_should_fail = fwd_bad and check_forward_ad
6468*da0073e9SAndroid Build Coastguard Worker                            bwd_should_fail = bwd_bad and check_backward_ad
6469*da0073e9SAndroid Build Coastguard Worker
6470*da0073e9SAndroid Build Coastguard Worker                            def run():
6471*da0073e9SAndroid Build Coastguard Worker                                gradcheck(
6472*da0073e9SAndroid Build Coastguard Worker                                    UserFn.apply,
6473*da0073e9SAndroid Build Coastguard Worker                                    (x, fwd_bad, bwd_bad),
6474*da0073e9SAndroid Build Coastguard Worker                                    check_forward_ad=check_forward_ad,
6475*da0073e9SAndroid Build Coastguard Worker                                    check_backward_ad=check_backward_ad,
6476*da0073e9SAndroid Build Coastguard Worker                                    check_undefined_grad=check_backward_ad,
6477*da0073e9SAndroid Build Coastguard Worker                                    check_batched_grad=check_backward_ad,
6478*da0073e9SAndroid Build Coastguard Worker                                    fast_mode=fast_mode,
6479*da0073e9SAndroid Build Coastguard Worker                                )
6480*da0073e9SAndroid Build Coastguard Worker
6481*da0073e9SAndroid Build Coastguard Worker                            x = torch.rand(2, dtype=torch.double, requires_grad=True)
6482*da0073e9SAndroid Build Coastguard Worker
6483*da0073e9SAndroid Build Coastguard Worker                            if not check_forward_ad and not check_backward_ad:
6484*da0073e9SAndroid Build Coastguard Worker                                with self.assertRaisesRegex(
6485*da0073e9SAndroid Build Coastguard Worker                                    AssertionError, "Expected at least one of"
6486*da0073e9SAndroid Build Coastguard Worker                                ):
6487*da0073e9SAndroid Build Coastguard Worker                                    run()
6488*da0073e9SAndroid Build Coastguard Worker                                continue
6489*da0073e9SAndroid Build Coastguard Worker
6490*da0073e9SAndroid Build Coastguard Worker                            if not fwd_should_fail and not bwd_should_fail:
6491*da0073e9SAndroid Build Coastguard Worker                                run()
6492*da0073e9SAndroid Build Coastguard Worker                            else:
6493*da0073e9SAndroid Build Coastguard Worker                                # If both fail, backward AD failure "hides" forward AD failure
6494*da0073e9SAndroid Build Coastguard Worker                                if fwd_should_fail:
6495*da0073e9SAndroid Build Coastguard Worker                                    fail_msg = fwd_fail_err_msg
6496*da0073e9SAndroid Build Coastguard Worker                                if bwd_should_fail:
6497*da0073e9SAndroid Build Coastguard Worker                                    fail_msg = bwd_fail_err_msg
6498*da0073e9SAndroid Build Coastguard Worker                                with self.assertRaisesRegex(RuntimeError, fail_msg):
6499*da0073e9SAndroid Build Coastguard Worker                                    run()
6500*da0073e9SAndroid Build Coastguard Worker
6501*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_forward_ad_batched_grad(self):
6502*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(2, dtype=torch.double, requires_grad=True)
6503*da0073e9SAndroid Build Coastguard Worker
6504*da0073e9SAndroid Build Coastguard Worker        # multiple inputs and outputs with non-tensor inputs
6505*da0073e9SAndroid Build Coastguard Worker        def fn1(a: torch.Tensor, b: int):
6506*da0073e9SAndroid Build Coastguard Worker            return a.clone(), a + 1
6507*da0073e9SAndroid Build Coastguard Worker
6508*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6509*da0073e9SAndroid Build Coastguard Worker            fn1,
6510*da0073e9SAndroid Build Coastguard Worker            (x, 1),
6511*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6512*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6513*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6514*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=False,
6515*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=True,
6516*da0073e9SAndroid Build Coastguard Worker        )
6517*da0073e9SAndroid Build Coastguard Worker
6518*da0073e9SAndroid Build Coastguard Worker        # unrelated inputs: tangent for c is None
6519*da0073e9SAndroid Build Coastguard Worker        def fn2(a: torch.Tensor, c: torch.Tensor):
6520*da0073e9SAndroid Build Coastguard Worker            return a.clone()
6521*da0073e9SAndroid Build Coastguard Worker
6522*da0073e9SAndroid Build Coastguard Worker        gradcheck(
6523*da0073e9SAndroid Build Coastguard Worker            fn2,
6524*da0073e9SAndroid Build Coastguard Worker            (x, x.clone()),
6525*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
6526*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
6527*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
6528*da0073e9SAndroid Build Coastguard Worker            check_undefined_grad=False,
6529*da0073e9SAndroid Build Coastguard Worker            check_batched_forward_grad=True,
6530*da0073e9SAndroid Build Coastguard Worker        )
6531*da0073e9SAndroid Build Coastguard Worker
6532*da0073e9SAndroid Build Coastguard Worker        class Fn(Function):
6533*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6534*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo):
6535*da0073e9SAndroid Build Coastguard Worker                return foo * 2
6536*da0073e9SAndroid Build Coastguard Worker
6537*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6538*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, gO):
6539*da0073e9SAndroid Build Coastguard Worker                return gO * 2
6540*da0073e9SAndroid Build Coastguard Worker
6541*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6542*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, gI):
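                # Calling a random op here makes the batched forward-grad check below fail,
                # since that check runs the jvp under vmap, which rejects random operations.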
6543*da0073e9SAndroid Build Coastguard Worker                torch.randn_like(gI)
6544*da0073e9SAndroid Build Coastguard Worker                return gI * 2
6545*da0073e9SAndroid Build Coastguard Worker
6546*da0073e9SAndroid Build Coastguard Worker        msg = "vmap: We do not yet support calling random operations inside of vmap"
6547*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
6548*da0073e9SAndroid Build Coastguard Worker            gradcheck(
6549*da0073e9SAndroid Build Coastguard Worker                Fn.apply, (x,), check_forward_ad=True, check_batched_forward_grad=True
6550*da0073e9SAndroid Build Coastguard Worker            )
6551*da0073e9SAndroid Build Coastguard Worker
6552*da0073e9SAndroid Build Coastguard Worker    def test_version_counter(self):
6553*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 2)
6554*da0073e9SAndroid Build Coastguard Worker
6555*da0073e9SAndroid Build Coastguard Worker        # In-place op bumps version
6556*da0073e9SAndroid Build Coastguard Worker        x_saved_version = x._version
6557*da0073e9SAndroid Build Coastguard Worker        x.add_(1).add_(1)
6558*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version > x_saved_version)
6559*da0073e9SAndroid Build Coastguard Worker
6560*da0073e9SAndroid Build Coastguard Worker        # Differentiable view shares version counter
6561*da0073e9SAndroid Build Coastguard Worker        xz = x[:]
6562*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version == xz._version)
6563*da0073e9SAndroid Build Coastguard Worker        xz.add_(1)
6564*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version == xz._version)
6565*da0073e9SAndroid Build Coastguard Worker
6566*da0073e9SAndroid Build Coastguard Worker        # `x.data = y` preserves version counter of `x`
6567*da0073e9SAndroid Build Coastguard Worker        x_saved_version = x._version
6568*da0073e9SAndroid Build Coastguard Worker        x.data = torch.randn(2, 3)
6569*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version == x_saved_version)
6570*da0073e9SAndroid Build Coastguard Worker        x.add_(1)
6571*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version > x_saved_version)
6572*da0073e9SAndroid Build Coastguard Worker        # Make sure `x` is still using the same version counter it shares with `xz`
6573*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version == xz._version)
6574*da0073e9SAndroid Build Coastguard Worker
6575*da0073e9SAndroid Build Coastguard Worker        # In-place op on `xz` also updates version of `x`,
6576*da0073e9SAndroid Build Coastguard Worker        # because they share the version counter
6577*da0073e9SAndroid Build Coastguard Worker        xz.add_(1)
6578*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x._version == xz._version)
6579*da0073e9SAndroid Build Coastguard Worker
6580*da0073e9SAndroid Build Coastguard Worker    def test_set_data_tensorimpl_type(self):
6581*da0073e9SAndroid Build Coastguard Worker        # Dense tensor has impl of type `TensorImpl`, while sparse tensor has impl
6582*da0073e9SAndroid Build Coastguard Worker        # of type `SparseTensorImpl`.
6583*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(1, 2)
6584*da0073e9SAndroid Build Coastguard Worker        x_s = torch.sparse_coo_tensor(torch.zeros([1, 1]), torch.ones([1]))
6585*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "incompatible tensor type"):
6586*da0073e9SAndroid Build Coastguard Worker            x.data = x_s
6587*da0073e9SAndroid Build Coastguard Worker
6588*da0073e9SAndroid Build Coastguard Worker    def test_set_data_preserve_pyobj(self):
6589*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 2)
6590*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(1, 2)
6591*da0073e9SAndroid Build Coastguard Worker        b_id_saved = id(b)
6592*da0073e9SAndroid Build Coastguard Worker        b.data = a
6593*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(b_id_saved == id(b))
6594*da0073e9SAndroid Build Coastguard Worker
6595*da0073e9SAndroid Build Coastguard Worker    def test_set_data_self_requires_grad(self):
6596*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
6597*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(2.0)
6598*da0073e9SAndroid Build Coastguard Worker        c = torch.tensor(3, dtype=torch.int64)
6599*da0073e9SAndroid Build Coastguard Worker        a.data = b
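        # `a` requires grad, so its data may only be replaced by a floating point or
        # complex tensor; assigning the int64 tensor below raises.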
6600*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6601*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "must be floating point or complex dtype"
6602*da0073e9SAndroid Build Coastguard Worker        ):
6603*da0073e9SAndroid Build Coastguard Worker            a.data = c
6604*da0073e9SAndroid Build Coastguard Worker
6605*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(IS_WINDOWS, "Skipping because it doesn't work on Windows")
6606*da0073e9SAndroid Build Coastguard Worker    def test_thread_shutdown(self):
6607*da0073e9SAndroid Build Coastguard Worker        code = """import torch
6608*da0073e9SAndroid Build Coastguard Workerfrom torch.autograd import Function
6609*da0073e9SAndroid Build Coastguard Workerclass MyFunction(Function):
6610*da0073e9SAndroid Build Coastguard Worker    @staticmethod
6611*da0073e9SAndroid Build Coastguard Worker    def forward(ctx, x):
6612*da0073e9SAndroid Build Coastguard Worker        return x
6613*da0073e9SAndroid Build Coastguard Worker
6614*da0073e9SAndroid Build Coastguard Worker    @staticmethod
6615*da0073e9SAndroid Build Coastguard Worker    def backward(ctx, grad):
6616*da0073e9SAndroid Build Coastguard Worker        return grad
6617*da0073e9SAndroid Build Coastguard Worker
6618*da0073e9SAndroid Build Coastguard Worker# Run on cuda if it is available to ensure that the worker thread
6619*da0073e9SAndroid Build Coastguard Worker# is properly initialized by the time we exit.
6620*da0073e9SAndroid Build Coastguard Workerdevice = "cuda" if torch.cuda.is_available() else "cpu"
6621*da0073e9SAndroid Build Coastguard Worker
6622*da0073e9SAndroid Build Coastguard Workerfor shape in [(1,), ()]:
6623*da0073e9SAndroid Build Coastguard Worker    v = torch.ones(shape, requires_grad=True, device=device)
6624*da0073e9SAndroid Build Coastguard Worker    MyFunction.apply(v).backward()
6625*da0073e9SAndroid Build Coastguard Worker"""
6626*da0073e9SAndroid Build Coastguard Worker        s = TestCase.runWithPytorchAPIUsageStderr(code)
6627*da0073e9SAndroid Build Coastguard Worker        # The autograd engine creates worker threads only when GPU (or other accelerator)
6628*da0073e9SAndroid Build Coastguard Worker        # devices are present. So make sure that we do shut down threads when we're testing
6629*da0073e9SAndroid Build Coastguard Worker        # with CUDA/MPS/XPU and that there is no thread to shut down when we're not.
6630*da0073e9SAndroid Build Coastguard Worker        if TEST_CUDA or torch.backends.mps.is_available() or torch.xpu.is_available():
6631*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
6632*da0073e9SAndroid Build Coastguard Worker        else:
6633*da0073e9SAndroid Build Coastguard Worker            self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
6634*da0073e9SAndroid Build Coastguard Worker
6635*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6636*da0073e9SAndroid Build Coastguard Worker        IS_MACOS,
6637*da0073e9SAndroid Build Coastguard Worker        "Fails with SIGBUS on macOS; https://github.com/pytorch/pytorch/issues/25941",
6638*da0073e9SAndroid Build Coastguard Worker    )
6639*da0073e9SAndroid Build Coastguard Worker    def test_deep_reentrant(self):
6640*da0073e9SAndroid Build Coastguard Worker        class DeepReentrant(Function):
6641*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6642*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
6643*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
6644*da0073e9SAndroid Build Coastguard Worker                    ctx.x = Variable(x.detach(), requires_grad=True)
6645*da0073e9SAndroid Build Coastguard Worker                    ctx.x = ctx.x - 1
6646*da0073e9SAndroid Build Coastguard Worker                return ctx.x.detach()
6647*da0073e9SAndroid Build Coastguard Worker
6648*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6649*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
6650*da0073e9SAndroid Build Coastguard Worker                if ctx.x < 0:
6651*da0073e9SAndroid Build Coastguard Worker                    return x
6652*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
6653*da0073e9SAndroid Build Coastguard Worker                    DeepReentrant.apply(ctx.x).sum().backward()
6654*da0073e9SAndroid Build Coastguard Worker                return x
6655*da0073e9SAndroid Build Coastguard Worker
6656*da0073e9SAndroid Build Coastguard Worker        # Test stack overflow escape mechanism
6657*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor(2000.0, requires_grad=True)
6658*da0073e9SAndroid Build Coastguard Worker        # This will cause a stack overflow if reentrant calls are handled
6659*da0073e9SAndroid Build Coastguard Worker        # in the same thread recursively
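        # NB: reentrant backward calls past a depth limit are expected to be offloaded to
        #     the engine's worker thread pool, which is the escape mechanism exercised here.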
6660*da0073e9SAndroid Build Coastguard Worker        DeepReentrant.apply(v).sum().backward()
6661*da0073e9SAndroid Build Coastguard Worker
6662*da0073e9SAndroid Build Coastguard Worker        # Test stack overflow escape mechanism multiple times
6663*da0073e9SAndroid Build Coastguard Worker        # to ensure reusing workers in the pool works fine
6664*da0073e9SAndroid Build Coastguard Worker        v2 = torch.tensor(200.0, requires_grad=True)
6665*da0073e9SAndroid Build Coastguard Worker        DeepReentrant.apply(v2).sum().backward()
6666*da0073e9SAndroid Build Coastguard Worker
6667*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_priority(self):
6668*da0073e9SAndroid Build Coastguard Worker        order = []
6669*da0073e9SAndroid Build Coastguard Worker
6670*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
6671*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6672*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
6673*da0073e9SAndroid Build Coastguard Worker                return x
6674*da0073e9SAndroid Build Coastguard Worker
6675*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6676*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
6677*da0073e9SAndroid Build Coastguard Worker                order.append("MyFunction")
6678*da0073e9SAndroid Build Coastguard Worker                return x
6679*da0073e9SAndroid Build Coastguard Worker
6680*da0073e9SAndroid Build Coastguard Worker        class Reentrant(Function):
6681*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6682*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
6683*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
6684*da0073e9SAndroid Build Coastguard Worker                    ctx.x = Variable(x.detach(), requires_grad=True)
6685*da0073e9SAndroid Build Coastguard Worker                    ctx.x = ctx.x - 1
6686*da0073e9SAndroid Build Coastguard Worker                return ctx.x.detach()
6687*da0073e9SAndroid Build Coastguard Worker
6688*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6689*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
6690*da0073e9SAndroid Build Coastguard Worker                order.append("Reentrant")
6691*da0073e9SAndroid Build Coastguard Worker                if ctx.x < 0:
6692*da0073e9SAndroid Build Coastguard Worker                    return x
6693*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
6694*da0073e9SAndroid Build Coastguard Worker                    Reentrant.apply(ctx.x).backward()
6695*da0073e9SAndroid Build Coastguard Worker                return x
6696*da0073e9SAndroid Build Coastguard Worker
6697*da0073e9SAndroid Build Coastguard Worker        a = MyFunction.apply(torch.tensor(6.0, requires_grad=True))
6698*da0073e9SAndroid Build Coastguard Worker        b = Reentrant.apply(torch.tensor(9.0, requires_grad=True))
6699*da0073e9SAndroid Build Coastguard Worker        v = a * b
6700*da0073e9SAndroid Build Coastguard Worker        v.backward()
6701*da0073e9SAndroid Build Coastguard Worker        # The tasks for the Reentrant and MyFunction backward() will be added
6702*da0073e9SAndroid Build Coastguard Worker        # to the queue in the autograd engine at the same time. The backward
6703*da0073e9SAndroid Build Coastguard Worker        # for Reentrant will be executed first, which will then add other
6704*da0073e9SAndroid Build Coastguard Worker        # backward tasks to the queue. We want to ensure all the reentrant tasks
6705*da0073e9SAndroid Build Coastguard Worker        # are prioritized over the MyFunction backward task regardless of their
6706*da0073e9SAndroid Build Coastguard Worker        # sequence numbers
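        # NB: Reentrant.apply(9.0) stores ctx.x = 8, and each backward re-applies Reentrant
        #     until ctx.x drops below zero, so "Reentrant" is recorded 10 times before the
        #     single "MyFunction" backward runs last.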
6707*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(order), 11)
6708*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(order.count("Reentrant"), 10)
6709*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(order[-1], "MyFunction")
6710*da0073e9SAndroid Build Coastguard Worker
6711*da0073e9SAndroid Build Coastguard Worker    @slowTest
6712*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing(self):
6713*da0073e9SAndroid Build Coastguard Worker        num_inp = 2000
6714*da0073e9SAndroid Build Coastguard Worker        nz_inp = 10
6715*da0073e9SAndroid Build Coastguard Worker        nz_out = 10
6716*da0073e9SAndroid Build Coastguard Worker        nz_bottleneck = 1000
6717*da0073e9SAndroid Build Coastguard Worker
6718*da0073e9SAndroid Build Coastguard Worker        # small proxy network for some complex reasoning we want to do per input
6719*da0073e9SAndroid Build Coastguard Worker        module = nn.Sequential(
6720*da0073e9SAndroid Build Coastguard Worker            nn.Linear(nz_inp, nz_bottleneck),
6721*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
6722*da0073e9SAndroid Build Coastguard Worker            nn.Linear(nz_bottleneck, nz_inp),
6723*da0073e9SAndroid Build Coastguard Worker        )
6724*da0073e9SAndroid Build Coastguard Worker
6725*da0073e9SAndroid Build Coastguard Worker        feat_combined = []
6726*da0073e9SAndroid Build Coastguard Worker        for r in range(num_inp):
6727*da0073e9SAndroid Build Coastguard Worker            data_r = torch.empty(1, nz_inp)
6728*da0073e9SAndroid Build Coastguard Worker            data_r.uniform_()
6729*da0073e9SAndroid Build Coastguard Worker            data_r.requires_grad = True
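            # With use_reentrant=True the intermediate activations of `module` are not
            # stored; they are recomputed when mean_combined.backward() runs below.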
6730*da0073e9SAndroid Build Coastguard Worker            feat_r = checkpoint(module, data_r, use_reentrant=True)
6731*da0073e9SAndroid Build Coastguard Worker            feat_combined.append(feat_r)
6732*da0073e9SAndroid Build Coastguard Worker
6733*da0073e9SAndroid Build Coastguard Worker        # compute mean as a proxy for some joint reasoning
6734*da0073e9SAndroid Build Coastguard Worker        mean_combined = torch.stack(feat_combined).mean()
6735*da0073e9SAndroid Build Coastguard Worker        mean_combined.backward()
6736*da0073e9SAndroid Build Coastguard Worker
6737*da0073e9SAndroid Build Coastguard Worker    def _test_checkpointing_non_reentrant_autocast(self, device_type):
6738*da0073e9SAndroid Build Coastguard Worker        for enabled in [True, False]:
6739*da0073e9SAndroid Build Coastguard Worker
6740*da0073e9SAndroid Build Coastguard Worker            def foo(x, y, z):
6741*da0073e9SAndroid Build Coastguard Worker                # torch.mm is on autocast's list of ops that should run in
6742*da0073e9SAndroid Build Coastguard Worker                # the autocast precision
6743*da0073e9SAndroid Build Coastguard Worker                x = torch.mm(x, y)
6744*da0073e9SAndroid Build Coastguard Worker                y = torch.mm(x, z)
6745*da0073e9SAndroid Build Coastguard Worker                z = torch.mm(z, z)
6746*da0073e9SAndroid Build Coastguard Worker                expected_dtype = torch.float32 if not enabled else torch.bfloat16
6747*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected_dtype, z.dtype)
6748*da0073e9SAndroid Build Coastguard Worker                return z
6749*da0073e9SAndroid Build Coastguard Worker
6750*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, 3, requires_grad=True)
6751*da0073e9SAndroid Build Coastguard Worker            y = torch.randn(3, 3, requires_grad=True)
6752*da0073e9SAndroid Build Coastguard Worker            z = torch.randn(3, 3, requires_grad=True)
6753*da0073e9SAndroid Build Coastguard Worker            if device_type == "cuda":
6754*da0073e9SAndroid Build Coastguard Worker                x = x.cuda()
6755*da0073e9SAndroid Build Coastguard Worker                y = y.cuda()
6756*da0073e9SAndroid Build Coastguard Worker                z = z.cuda()
6757*da0073e9SAndroid Build Coastguard Worker
6758*da0073e9SAndroid Build Coastguard Worker            with torch.autocast(
6759*da0073e9SAndroid Build Coastguard Worker                enabled=enabled, device_type=device_type, dtype=torch.bfloat16
6760*da0073e9SAndroid Build Coastguard Worker            ):
6761*da0073e9SAndroid Build Coastguard Worker                loss = checkpoint(foo, x, y, z, use_reentrant=False)
6762*da0073e9SAndroid Build Coastguard Worker                loss = loss.sum()
6763*da0073e9SAndroid Build Coastguard Worker
6764*da0073e9SAndroid Build Coastguard Worker            # Without saving and re-applying the autocast state for recomputation, autograd
6765*da0073e9SAndroid Build Coastguard Worker            # would raise an error about mismatched dtypes.
6766*da0073e9SAndroid Build Coastguard Worker            loss.backward()  # triggers recomputation to check it runs in bfloat16
6767*da0073e9SAndroid Build Coastguard Worker
6768*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_non_reentrant_autocast_cpu(self):
6769*da0073e9SAndroid Build Coastguard Worker        """
6770*da0073e9SAndroid Build Coastguard Worker        Test that autocast args such as the dtype are preserved during non-reentrant
6771*da0073e9SAndroid Build Coastguard Worker        checkpoint recomputation on CPU.
6772*da0073e9SAndroid Build Coastguard Worker        """
6773*da0073e9SAndroid Build Coastguard Worker        self._test_checkpointing_non_reentrant_autocast(device_type="cpu")
6774*da0073e9SAndroid Build Coastguard Worker
6775*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(
6776*da0073e9SAndroid Build Coastguard Worker        not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(),
6777*da0073e9SAndroid Build Coastguard Worker        "Test requires CUDA bf16 support",
6778*da0073e9SAndroid Build Coastguard Worker    )
6779*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_non_reentrant_autocast_gpu(self):
6780*da0073e9SAndroid Build Coastguard Worker        """
6781*da0073e9SAndroid Build Coastguard Worker        Test that autocast args/kwargs such as the dtype are preserved during
6782*da0073e9SAndroid Build Coastguard Worker        non-reentrant checkpoint recomputation on GPU.
6783*da0073e9SAndroid Build Coastguard Worker        """
6784*da0073e9SAndroid Build Coastguard Worker        self._test_checkpointing_non_reentrant_autocast(device_type="cuda")
6785*da0073e9SAndroid Build Coastguard Worker
6786*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
6787*da0073e9SAndroid Build Coastguard Worker    @slowTest
6788*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_memory_savings(self):
6789*da0073e9SAndroid Build Coastguard Worker        class MyModel(nn.Module):
6790*da0073e9SAndroid Build Coastguard Worker            def __init__(self, n, use_checkpoint, use_reentrant):
6791*da0073e9SAndroid Build Coastguard Worker                super().__init__()
6792*da0073e9SAndroid Build Coastguard Worker                self.n = n
6793*da0073e9SAndroid Build Coastguard Worker                self.use_checkpoint = use_checkpoint
6794*da0073e9SAndroid Build Coastguard Worker                self.use_reentrant = use_reentrant
6795*da0073e9SAndroid Build Coastguard Worker                self.layers = nn.ModuleList()
6796*da0073e9SAndroid Build Coastguard Worker                for i in range(self.n):
6797*da0073e9SAndroid Build Coastguard Worker                    layer = nn.Sequential(
6798*da0073e9SAndroid Build Coastguard Worker                        nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
6799*da0073e9SAndroid Build Coastguard Worker                    )
6800*da0073e9SAndroid Build Coastguard Worker                    self.layers.append(layer)
6801*da0073e9SAndroid Build Coastguard Worker                # pre-allocate the grad so that increased memory usage is mainly
6802*da0073e9SAndroid Build Coastguard Worker                # due to activations.
6803*da0073e9SAndroid Build Coastguard Worker                for layer in self.layers:
6804*da0073e9SAndroid Build Coastguard Worker                    for lin in layer:
6805*da0073e9SAndroid Build Coastguard Worker                        lin.weight.grad = torch.ones_like(lin.weight)
6806*da0073e9SAndroid Build Coastguard Worker                        lin.bias.grad = torch.ones_like(lin.bias)
6807*da0073e9SAndroid Build Coastguard Worker
6808*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
6809*da0073e9SAndroid Build Coastguard Worker                for i in range(self.n):
6810*da0073e9SAndroid Build Coastguard Worker                    if not self.use_checkpoint:
6811*da0073e9SAndroid Build Coastguard Worker                        x = self.layers[i](x)
6812*da0073e9SAndroid Build Coastguard Worker                    else:
6813*da0073e9SAndroid Build Coastguard Worker                        x = checkpoint(
6814*da0073e9SAndroid Build Coastguard Worker                            self.layers[i], x, use_reentrant=self.use_reentrant
6815*da0073e9SAndroid Build Coastguard Worker                        )
6816*da0073e9SAndroid Build Coastguard Worker
6817*da0073e9SAndroid Build Coastguard Worker                return x
6818*da0073e9SAndroid Build Coastguard Worker
6819*da0073e9SAndroid Build Coastguard Worker        model_no_checkpoint = MyModel(
6820*da0073e9SAndroid Build Coastguard Worker            8, use_checkpoint=False, use_reentrant=False
6821*da0073e9SAndroid Build Coastguard Worker        ).cuda()
6822*da0073e9SAndroid Build Coastguard Worker        model_reentrant_checkpoint = MyModel(
6823*da0073e9SAndroid Build Coastguard Worker            8, use_checkpoint=True, use_reentrant=True
6824*da0073e9SAndroid Build Coastguard Worker        ).cuda()
6825*da0073e9SAndroid Build Coastguard Worker        model_no_reentrant_checkpoint = MyModel(
6826*da0073e9SAndroid Build Coastguard Worker            8, use_checkpoint=True, use_reentrant=False
6827*da0073e9SAndroid Build Coastguard Worker        ).cuda()
6828*da0073e9SAndroid Build Coastguard Worker
6829*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(100, 256, requires_grad=True, device="cuda")
6830*da0073e9SAndroid Build Coastguard Worker
6831*da0073e9SAndroid Build Coastguard Worker        torch.cuda.reset_peak_memory_stats()
6832*da0073e9SAndroid Build Coastguard Worker        loss = model_no_checkpoint(x.clone()).sum()
6833*da0073e9SAndroid Build Coastguard Worker        loss.backward()
6834*da0073e9SAndroid Build Coastguard Worker        mem_no_checkpoint = torch.cuda.max_memory_allocated()
6835*da0073e9SAndroid Build Coastguard Worker
6836*da0073e9SAndroid Build Coastguard Worker        torch.cuda.reset_peak_memory_stats()
6837*da0073e9SAndroid Build Coastguard Worker        loss = model_reentrant_checkpoint(x.clone()).sum()
6838*da0073e9SAndroid Build Coastguard Worker        loss.backward()
6839*da0073e9SAndroid Build Coastguard Worker        mem_reentrant_checkpoint = torch.cuda.max_memory_allocated()
6840*da0073e9SAndroid Build Coastguard Worker
6841*da0073e9SAndroid Build Coastguard Worker        torch.cuda.reset_peak_memory_stats()
6842*da0073e9SAndroid Build Coastguard Worker        loss = model_no_reentrant_checkpoint(x.clone()).sum()
6843*da0073e9SAndroid Build Coastguard Worker        loss.backward()
6844*da0073e9SAndroid Build Coastguard Worker        mem_no_reentrant_checkpoint = torch.cuda.max_memory_allocated()
6845*da0073e9SAndroid Build Coastguard Worker
6846*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mem_reentrant_checkpoint < mem_no_checkpoint)
6847*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(mem_no_reentrant_checkpoint < mem_no_checkpoint)
6848*da0073e9SAndroid Build Coastguard Worker
6849*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_custom_function_works(self):
6850*da0073e9SAndroid Build Coastguard Worker        msg = "Unpack is being triggered for a tensor that was already unpacked once"
6851*da0073e9SAndroid Build Coastguard Worker
6852*da0073e9SAndroid Build Coastguard Worker        class MyFunc(torch.autograd.Function):
6853*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6854*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y, z):
6855*da0073e9SAndroid Build Coastguard Worker                w = x * y * z
6856*da0073e9SAndroid Build Coastguard Worker                out = w + w
6857*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x, y, z, w, out)
6858*da0073e9SAndroid Build Coastguard Worker                return out
6859*da0073e9SAndroid Build Coastguard Worker
6860*da0073e9SAndroid Build Coastguard Worker            @staticmethod
6861*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_out):
6862*da0073e9SAndroid Build Coastguard Worker                x, y, z, w, out = ctx.saved_tensors
6863*da0073e9SAndroid Build Coastguard Worker                # Accessing the saved Tensors a second time will raise because
6864*da0073e9SAndroid Build Coastguard Worker                # recomputed tensors get cleared as soon as they are unpacked.
6865*da0073e9SAndroid Build Coastguard Worker                # A recomputation is only triggered if your backward has a new
6866*da0073e9SAndroid Build Coastguard Worker                # graph-task id.
6867*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, msg):
6868*da0073e9SAndroid Build Coastguard Worker                    x_2, y_2, z_2, w_2, out_2 = ctx.saved_tensors
6869*da0073e9SAndroid Build Coastguard Worker                return x, y, z
6870*da0073e9SAndroid Build Coastguard Worker
6871*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1.0, requires_grad=True)
6872*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(2.0, requires_grad=True)
6873*da0073e9SAndroid Build Coastguard Worker        z = torch.tensor(3.0, requires_grad=True)
6874*da0073e9SAndroid Build Coastguard Worker
6875*da0073e9SAndroid Build Coastguard Worker        def foo(x, y, z):
6876*da0073e9SAndroid Build Coastguard Worker            x = x * y * z
6877*da0073e9SAndroid Build Coastguard Worker            y = y * y * z
6878*da0073e9SAndroid Build Coastguard Worker            z = z * z
6879*da0073e9SAndroid Build Coastguard Worker            out = MyFunc.apply(x, y, z)
6880*da0073e9SAndroid Build Coastguard Worker            return out
6881*da0073e9SAndroid Build Coastguard Worker
6882*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(foo, x, y, z, use_reentrant=False)
6883*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
6884*da0073e9SAndroid Build Coastguard Worker
6885*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_with_context_fn(self):
6886*da0073e9SAndroid Build Coastguard Worker        class VerboseTorchDispatchMode(TorchDispatchMode):
6887*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
6888*da0073e9SAndroid Build Coastguard Worker                self.operators = []
6889*da0073e9SAndroid Build Coastguard Worker
6890*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
6891*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
6892*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
6893*da0073e9SAndroid Build Coastguard Worker                self.operators.append(func.__name__)
6894*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
6895*da0073e9SAndroid Build Coastguard Worker
6896*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1.0, requires_grad=True)
6897*da0073e9SAndroid Build Coastguard Worker        verbose_mode = VerboseTorchDispatchMode()
6898*da0073e9SAndroid Build Coastguard Worker
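        # context_fn returns a pair of context managers: the first is entered around the
        # original forward, the second around the recomputation during backward.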
6899*da0073e9SAndroid Build Coastguard Worker        def context_fn():
6900*da0073e9SAndroid Build Coastguard Worker            return verbose_mode, contextlib.nullcontext()
6901*da0073e9SAndroid Build Coastguard Worker
6902*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(
6903*da0073e9SAndroid Build Coastguard Worker            lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
6904*da0073e9SAndroid Build Coastguard Worker        )
6905*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(verbose_mode.operators, ["exp.default"])
6906*da0073e9SAndroid Build Coastguard Worker
6907*da0073e9SAndroid Build Coastguard Worker        verbose_mode.operators = []
6908*da0073e9SAndroid Build Coastguard Worker
6909*da0073e9SAndroid Build Coastguard Worker        def context_fn():
6910*da0073e9SAndroid Build Coastguard Worker            return contextlib.nullcontext(), verbose_mode
6911*da0073e9SAndroid Build Coastguard Worker
6912*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(
6913*da0073e9SAndroid Build Coastguard Worker            lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
6914*da0073e9SAndroid Build Coastguard Worker        )
6915*da0073e9SAndroid Build Coastguard Worker        out.backward()
6916*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
6917*da0073e9SAndroid Build Coastguard Worker            verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
6918*da0073e9SAndroid Build Coastguard Worker        )
6919*da0073e9SAndroid Build Coastguard Worker
6920*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6921*da0073e9SAndroid Build Coastguard Worker            Exception, "only supported when use_reentrant=False"
6922*da0073e9SAndroid Build Coastguard Worker        ):
6923*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(
6924*da0073e9SAndroid Build Coastguard Worker                lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn
6925*da0073e9SAndroid Build Coastguard Worker            )
6926*da0073e9SAndroid Build Coastguard Worker
6927*da0073e9SAndroid Build Coastguard Worker    def test_checkpoint_warns_if_use_reentrant_not_passed_explicitly(self):
6928*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, requires_grad=True)
6929*da0073e9SAndroid Build Coastguard Worker
6930*da0073e9SAndroid Build Coastguard Worker        # Passing explicitly should not warn
6931*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(lambda: checkpoint(lambda x: x, a, use_reentrant=False))
6932*da0073e9SAndroid Build Coastguard Worker
6933*da0073e9SAndroid Build Coastguard Worker        # Not passing explicitly warns
6934*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
6935*da0073e9SAndroid Build Coastguard Worker            UserWarning, ".*the use_reentrant parameter should be passed explicitly.*"
6936*da0073e9SAndroid Build Coastguard Worker        ):
6937*da0073e9SAndroid Build Coastguard Worker            checkpoint(lambda x: x, a)
6938*da0073e9SAndroid Build Coastguard Worker
6939*da0073e9SAndroid Build Coastguard Worker    def test_checkpoint_sequential_warns_if_use_reentrant_not_passed_explicitly(self):
6940*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, requires_grad=True)
6941*da0073e9SAndroid Build Coastguard Worker        modules_list = [
6942*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(3, 3),
6943*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(3, 3),
6944*da0073e9SAndroid Build Coastguard Worker            torch.nn.Linear(3, 3),
6945*da0073e9SAndroid Build Coastguard Worker        ]
6946*da0073e9SAndroid Build Coastguard Worker
6947*da0073e9SAndroid Build Coastguard Worker        # Passing explicitly should not warn
6948*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(
6949*da0073e9SAndroid Build Coastguard Worker            lambda: checkpoint_sequential(modules_list, 3, a, use_reentrant=False)
6950*da0073e9SAndroid Build Coastguard Worker        )
6951*da0073e9SAndroid Build Coastguard Worker
6952*da0073e9SAndroid Build Coastguard Worker        # Not passing explicitly warns
6953*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsOnceRegex(
6954*da0073e9SAndroid Build Coastguard Worker            UserWarning, ".*the use_reentrant parameter should be passed explicitly.*"
6955*da0073e9SAndroid Build Coastguard Worker        ):
6956*da0073e9SAndroid Build Coastguard Worker            checkpoint_sequential(modules_list, 3, a)
6957*da0073e9SAndroid Build Coastguard Worker
6958*da0073e9SAndroid Build Coastguard Worker    def test_checkpoint_detects_non_determinism(self):
6959*da0073e9SAndroid Build Coastguard Worker        def save_3_tensors(x):
6960*da0073e9SAndroid Build Coastguard Worker            out = x.sin().exp()
6961*da0073e9SAndroid Build Coastguard Worker            out = out.sin()
6962*da0073e9SAndroid Build Coastguard Worker            return out
6963*da0073e9SAndroid Build Coastguard Worker
6964*da0073e9SAndroid Build Coastguard Worker        def save_2_tensors(x):
6965*da0073e9SAndroid Build Coastguard Worker            return x.sin().exp()
6966*da0073e9SAndroid Build Coastguard Worker
6967*da0073e9SAndroid Build Coastguard Worker        def save_2_tensors_alt(x):
6968*da0073e9SAndroid Build Coastguard Worker            return x.sin() * torch.tensor([1.0, 2.0])
6969*da0073e9SAndroid Build Coastguard Worker
6970*da0073e9SAndroid Build Coastguard Worker        def get_non_det_fn(orig_fn, recompute_fn):
6971*da0073e9SAndroid Build Coastguard Worker            counter = [0]
6972*da0073e9SAndroid Build Coastguard Worker
6973*da0073e9SAndroid Build Coastguard Worker            def fn(x):
6974*da0073e9SAndroid Build Coastguard Worker                if counter[0] == 0:
6975*da0073e9SAndroid Build Coastguard Worker                    counter[0] += 1
6976*da0073e9SAndroid Build Coastguard Worker                    return orig_fn(x)
6977*da0073e9SAndroid Build Coastguard Worker                else:
6978*da0073e9SAndroid Build Coastguard Worker                    return recompute_fn(x)
6979*da0073e9SAndroid Build Coastguard Worker
6980*da0073e9SAndroid Build Coastguard Worker            return fn
6981*da0073e9SAndroid Build Coastguard Worker
6982*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, requires_grad=True)
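        # Non-reentrant checkpoint compares the tensors saved during the original forward
        # with those saved during recomputation and raises if the count or their metadata
        # differ (unless early stopping hides the mismatch, as exercised below).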
6983*da0073e9SAndroid Build Coastguard Worker
6984*da0073e9SAndroid Build Coastguard Worker        # Save fewer tensors during recompute
6985*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_3_tensors, recompute_fn=save_2_tensors)
6986*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
6987*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "A different number of tensors was saved"
6988*da0073e9SAndroid Build Coastguard Worker        ):
6989*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=False)
6990*da0073e9SAndroid Build Coastguard Worker            out.backward()
6991*da0073e9SAndroid Build Coastguard Worker
6992*da0073e9SAndroid Build Coastguard Worker        # Save more tensors during recompute
6993*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors)
6994*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
6995*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
6996*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "trying to save more tensors during recomputation"
6997*da0073e9SAndroid Build Coastguard Worker            ):
6998*da0073e9SAndroid Build Coastguard Worker                out = checkpoint(fn, a, use_reentrant=False)
6999*da0073e9SAndroid Build Coastguard Worker                out.backward()
7000*da0073e9SAndroid Build Coastguard Worker
7001*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors)
7002*da0073e9SAndroid Build Coastguard Worker        # If early stopping is enabled, we would not raise (the results would be correct anyway)
7003*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, a, use_reentrant=False)
7004*da0073e9SAndroid Build Coastguard Worker        out.backward()
7005*da0073e9SAndroid Build Coastguard Worker
7006*da0073e9SAndroid Build Coastguard Worker        # Save the same number of tensors but the shape is different
7007*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7008*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "tensors have different metadata"):
7009*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=False)
7010*da0073e9SAndroid Build Coastguard Worker            out.backward()
7011*da0073e9SAndroid Build Coastguard Worker
7012*da0073e9SAndroid Build Coastguard Worker        # Get the debug message if debug=True
7013*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7014*da0073e9SAndroid Build Coastguard Worker
7015*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7016*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7017*da0073e9SAndroid Build Coastguard Worker            "You are seeing this error because you passed `debug=True` to checkpoint",
7018*da0073e9SAndroid Build Coastguard Worker        ):
7019*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=False, debug=True)
7020*da0073e9SAndroid Build Coastguard Worker            out.backward()
7021*da0073e9SAndroid Build Coastguard Worker
7022*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7023*da0073e9SAndroid Build Coastguard Worker
7024*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7025*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7026*da0073e9SAndroid Build Coastguard Worker            "You are seeing this error because you passed `debug=True` to checkpoint",
7027*da0073e9SAndroid Build Coastguard Worker        ):
7028*da0073e9SAndroid Build Coastguard Worker            with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
7029*da0073e9SAndroid Build Coastguard Worker                out = checkpoint(fn, a, use_reentrant=False, debug=False)
7030*da0073e9SAndroid Build Coastguard Worker                out.backward()
7031*da0073e9SAndroid Build Coastguard Worker
7032*da0073e9SAndroid Build Coastguard Worker        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
7033*da0073e9SAndroid Build Coastguard Worker
7034*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7035*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Recomputed values for the following tensors have different"
7036*da0073e9SAndroid Build Coastguard Worker        ):
7037*da0073e9SAndroid Build Coastguard Worker            with torch.utils.checkpoint.set_checkpoint_debug_enabled(False):
7038*da0073e9SAndroid Build Coastguard Worker                out = checkpoint(fn, a, use_reentrant=False, debug=True)
7039*da0073e9SAndroid Build Coastguard Worker                out.backward()
7040*da0073e9SAndroid Build Coastguard Worker
7041*da0073e9SAndroid Build Coastguard Worker    def test_access_saved_tensor_twice_without_recomputation_works(self):
7042*da0073e9SAndroid Build Coastguard Worker        count = [0]
7043*da0073e9SAndroid Build Coastguard Worker
7044*da0073e9SAndroid Build Coastguard Worker        def foo(a):
7045*da0073e9SAndroid Build Coastguard Worker            count[0] += 1
7046*da0073e9SAndroid Build Coastguard Worker            b = a * a
7047*da0073e9SAndroid Build Coastguard Worker            c = a * b
7048*da0073e9SAndroid Build Coastguard Worker            d = torch.exp(a)
7049*da0073e9SAndroid Build Coastguard Worker            return d
7050*da0073e9SAndroid Build Coastguard Worker
7051*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
7052*da0073e9SAndroid Build Coastguard Worker        d = checkpoint(foo, a, use_reentrant=False)
7053*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 1)
7054*da0073e9SAndroid Build Coastguard Worker        # Recomputed variables only persist within a particular backward call.
7055*da0073e9SAndroid Build Coastguard Worker        # If _saved_result is accessed outside of a backward, it will trigger
7056*da0073e9SAndroid Build Coastguard Worker        # a recompute. And afterwards, those recomputed results are immediately
7057*da0073e9SAndroid Build Coastguard Worker        # cleared.
7058*da0073e9SAndroid Build Coastguard Worker        d.grad_fn._saved_result
7059*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 2)
7060*da0073e9SAndroid Build Coastguard Worker        # Second access will trigger another recompute
7061*da0073e9SAndroid Build Coastguard Worker        d.grad_fn._saved_result
7062*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 3)
7063*da0073e9SAndroid Build Coastguard Worker        # Backward clears the saved variable
7064*da0073e9SAndroid Build Coastguard Worker        d.sum().backward()
7065*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 4)
7066*da0073e9SAndroid Build Coastguard Worker        # Now it raises an error
7067*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7068*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7069*da0073e9SAndroid Build Coastguard Worker            "or directly access saved tensors after they have already been freed",
7070*da0073e9SAndroid Build Coastguard Worker        ):
7071*da0073e9SAndroid Build Coastguard Worker            d.grad_fn._saved_result
7072*da0073e9SAndroid Build Coastguard Worker
7073*da0073e9SAndroid Build Coastguard Worker    @slowTest
7074*da0073e9SAndroid Build Coastguard Worker    @parametrize("input_requires_grad", [True, False])
7075*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant(self, input_requires_grad):
7076*da0073e9SAndroid Build Coastguard Worker        """
7077*da0073e9SAndroid Build Coastguard Worker        Basic test for checkpoint without reentrant autograd.
7078*da0073e9SAndroid Build Coastguard Worker        """
7079*da0073e9SAndroid Build Coastguard Worker        num_inp = 2000
7080*da0073e9SAndroid Build Coastguard Worker        nz_inp = 10
7081*da0073e9SAndroid Build Coastguard Worker        nz_out = 10
7082*da0073e9SAndroid Build Coastguard Worker        nz_bottleneck = 1000
7083*da0073e9SAndroid Build Coastguard Worker
7084*da0073e9SAndroid Build Coastguard Worker        # small proxy network for some complex reasoning we want to do per input
7085*da0073e9SAndroid Build Coastguard Worker        module = nn.Sequential(
7086*da0073e9SAndroid Build Coastguard Worker            nn.Linear(nz_inp, nz_bottleneck),
7087*da0073e9SAndroid Build Coastguard Worker            nn.ReLU(),
7088*da0073e9SAndroid Build Coastguard Worker            nn.Linear(nz_bottleneck, nz_inp),
7089*da0073e9SAndroid Build Coastguard Worker        )
7090*da0073e9SAndroid Build Coastguard Worker
7091*da0073e9SAndroid Build Coastguard Worker        # Module holder for testing that activation checkpointing with
7092*da0073e9SAndroid Build Coastguard Worker        # use_reentrant=False supports kwargs.
7093*da0073e9SAndroid Build Coastguard Worker        class MyModule(nn.Module):
7094*da0073e9SAndroid Build Coastguard Worker            def __init__(self, mod):
7095*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7096*da0073e9SAndroid Build Coastguard Worker                self.module = mod
7097*da0073e9SAndroid Build Coastguard Worker
7098*da0073e9SAndroid Build Coastguard Worker            def forward(self, data):
7099*da0073e9SAndroid Build Coastguard Worker                return self.module(data)
7100*da0073e9SAndroid Build Coastguard Worker
7101*da0073e9SAndroid Build Coastguard Worker        module = MyModule(mod=module)
7102*da0073e9SAndroid Build Coastguard Worker
7103*da0073e9SAndroid Build Coastguard Worker        # Run the model with and without checkpointing and verify the gradients are
7104*da0073e9SAndroid Build Coastguard Worker        # equivalent, regardless of whether the inputs require grad.
7105*da0073e9SAndroid Build Coastguard Worker        module_copy = deepcopy(module)
7106*da0073e9SAndroid Build Coastguard Worker
7107*da0073e9SAndroid Build Coastguard Worker        feat_combined = []
7108*da0073e9SAndroid Build Coastguard Worker        feat_combined_no_checkpoint = []
7109*da0073e9SAndroid Build Coastguard Worker        for r in range(num_inp):
7110*da0073e9SAndroid Build Coastguard Worker            data_r = torch.empty(1, nz_inp)
7111*da0073e9SAndroid Build Coastguard Worker            data_r.uniform_()
7112*da0073e9SAndroid Build Coastguard Worker            data_r.requires_grad = input_requires_grad
7113*da0073e9SAndroid Build Coastguard Worker            data_r_copy = data_r.clone()
7114*da0073e9SAndroid Build Coastguard Worker            feat_r = checkpoint(module, data=data_r, use_reentrant=False)
7115*da0073e9SAndroid Build Coastguard Worker            feat_combined.append(feat_r)
7116*da0073e9SAndroid Build Coastguard Worker            feat_r_no_checkpoint = module_copy(data_r)
7117*da0073e9SAndroid Build Coastguard Worker            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)
7118*da0073e9SAndroid Build Coastguard Worker
7119*da0073e9SAndroid Build Coastguard Worker        # compute mean as a proxy for some joint reasoning
7120*da0073e9SAndroid Build Coastguard Worker        mean_combined = torch.stack(feat_combined).mean()
7121*da0073e9SAndroid Build Coastguard Worker        mean_combined.backward()
7122*da0073e9SAndroid Build Coastguard Worker        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
7123*da0073e9SAndroid Build Coastguard Worker        mean_combined_no_checkpoint.backward()
7124*da0073e9SAndroid Build Coastguard Worker
7125*da0073e9SAndroid Build Coastguard Worker        for checkpoint_param, param in zip(
7126*da0073e9SAndroid Build Coastguard Worker            module.parameters(), module_copy.parameters()
7127*da0073e9SAndroid Build Coastguard Worker        ):
7128*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(checkpoint_param.grad, param.grad)
7129*da0073e9SAndroid Build Coastguard Worker
7130*da0073e9SAndroid Build Coastguard Worker    def test_checkpoint_valid_reset_on_error(self):
7131*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, requires_grad=True)
7132*da0073e9SAndroid Build Coastguard Worker
7133*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7134*da0073e9SAndroid Build Coastguard Worker            Exception, "torch.utils.checkpoint is incompatible"
7135*da0073e9SAndroid Build Coastguard Worker        ):
7136*da0073e9SAndroid Build Coastguard Worker            b = checkpoint(torch.exp, a, use_reentrant=True).sum()
7137*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(b, (a,))
7138*da0073e9SAndroid Build Coastguard Worker
7139*da0073e9SAndroid Build Coastguard Worker        c = checkpoint(torch.exp, a, use_reentrant=True).sum()
7140*da0073e9SAndroid Build Coastguard Worker        c.backward()
7141*da0073e9SAndroid Build Coastguard Worker
7142*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_reentrant", [True, False])
7143*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
7144*da0073e9SAndroid Build Coastguard Worker        class NoGradModule(torch.nn.Module):
7145*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7146*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7147*da0073e9SAndroid Build Coastguard Worker                self.linear = nn.Linear(2, 2, bias=False)
7148*da0073e9SAndroid Build Coastguard Worker                self.lin2 = nn.Linear(2, 2, bias=False)
7149*da0073e9SAndroid Build Coastguard Worker
7150*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
7151*da0073e9SAndroid Build Coastguard Worker                with torch.no_grad():
7152*da0073e9SAndroid Build Coastguard Worker                    return self.lin2(self.linear(x))
7153*da0073e9SAndroid Build Coastguard Worker
7154*da0073e9SAndroid Build Coastguard Worker        module = NoGradModule()
7155*da0073e9SAndroid Build Coastguard Worker
7156*da0073e9SAndroid Build Coastguard Worker        err_ctx = (
7157*da0073e9SAndroid Build Coastguard Worker            self.assertRaisesRegex(
7158*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "none of output has requires_grad=True"
7159*da0073e9SAndroid Build Coastguard Worker            )
7160*da0073e9SAndroid Build Coastguard Worker            if use_reentrant
7161*da0073e9SAndroid Build Coastguard Worker            else contextlib.nullcontext()
7162*da0073e9SAndroid Build Coastguard Worker        )
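        # Intent of err_ctx: the reentrant variant is expected to raise because
        # none of the checkpointed outputs require grad, while the non-reentrant
        # variant is expected to run without recording anything, hence the
        # nullcontext for that case.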
7163*da0073e9SAndroid Build Coastguard Worker
7164*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, requires_grad=True)
7165*da0073e9SAndroid Build Coastguard Worker        for _ in range(3):
7166*da0073e9SAndroid Build Coastguard Worker            with err_ctx:
7167*da0073e9SAndroid Build Coastguard Worker                # out does not require grad
7168*da0073e9SAndroid Build Coastguard Worker                out = checkpoint(module, a, use_reentrant=use_reentrant)
7169*da0073e9SAndroid Build Coastguard Worker                # Make loss require grad, otherwise we would run into
7170*da0073e9SAndroid Build Coastguard Worker                # "element 0 of tensors does not require grad and does not have a grad_fn"
7171*da0073e9SAndroid Build Coastguard Worker                out += a
7172*da0073e9SAndroid Build Coastguard Worker                out.sum().backward()
7173*da0073e9SAndroid Build Coastguard Worker
7174*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_saved_object_identity(self):
7175*da0073e9SAndroid Build Coastguard Worker        x_backward = None
7176*da0073e9SAndroid Build Coastguard Worker
7177*da0073e9SAndroid Build Coastguard Worker        class Test(torch.autograd.Function):
7178*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7179*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
7180*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(y)
7181*da0073e9SAndroid Build Coastguard Worker                return x
7182*da0073e9SAndroid Build Coastguard Worker
7183*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7184*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
7185*da0073e9SAndroid Build Coastguard Worker                nonlocal x_backward
7186*da0073e9SAndroid Build Coastguard Worker                (x_backward,) = ctx.saved_tensors
7187*da0073e9SAndroid Build Coastguard Worker                return x, None
7188*da0073e9SAndroid Build Coastguard Worker
7189*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
7190*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(1.0, requires_grad=False)
7191*da0073e9SAndroid Build Coastguard Worker
7192*da0073e9SAndroid Build Coastguard Worker        Test.apply(a, b).backward()
7193*da0073e9SAndroid Build Coastguard Worker        self.assertIs(b, x_backward)
7194*da0073e9SAndroid Build Coastguard Worker
7195*da0073e9SAndroid Build Coastguard Worker        x_backward = None
7196*da0073e9SAndroid Build Coastguard Worker        checkpoint(Test.apply, a, b, use_reentrant=False).backward()
7197*da0073e9SAndroid Build Coastguard Worker        self.assertIs(b, x_backward)
7198*da0073e9SAndroid Build Coastguard Worker
7199*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_correct_grad(self):
7200*da0073e9SAndroid Build Coastguard Worker        """
7201*da0073e9SAndroid Build Coastguard Worker        Verifies that correct gradients are calculated for checkpoint
7202*da0073e9SAndroid Build Coastguard Worker        without reentrant autograd, for both backward() and autograd.grad().
7203*da0073e9SAndroid Build Coastguard Worker        """
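        # Note that torch.autograd.grad() returns the gradients rather than
        # accumulating them into .grad, which is why d_grad below is unpacked
        # from the returned tuple instead of being read from a.grad.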
7204*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, requires_grad=True)
7205*da0073e9SAndroid Build Coastguard Worker
7206*da0073e9SAndroid Build Coastguard Worker        b = torch.exp(a).sum()
7207*da0073e9SAndroid Build Coastguard Worker        b.backward()
7208*da0073e9SAndroid Build Coastguard Worker        b_grad = a.grad
7209*da0073e9SAndroid Build Coastguard Worker
7210*da0073e9SAndroid Build Coastguard Worker        a.grad = None
7211*da0073e9SAndroid Build Coastguard Worker        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
7212*da0073e9SAndroid Build Coastguard Worker        c.backward()
7213*da0073e9SAndroid Build Coastguard Worker        c_grad = a.grad
7214*da0073e9SAndroid Build Coastguard Worker
7215*da0073e9SAndroid Build Coastguard Worker        a.grad = None
7216*da0073e9SAndroid Build Coastguard Worker        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
7217*da0073e9SAndroid Build Coastguard Worker        (d_grad,) = torch.autograd.grad(d, (a,))
7218*da0073e9SAndroid Build Coastguard Worker
7219*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_grad, c_grad)
7220*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_grad, d_grad)
7221*da0073e9SAndroid Build Coastguard Worker
7222*da0073e9SAndroid Build Coastguard Worker    # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally
7223*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115")
7224*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_dataparallel(self):
7225*da0073e9SAndroid Build Coastguard Worker        """
7226*da0073e9SAndroid Build Coastguard Worker        Verifies gradient correctness when checkpoint without reentrant autograd
7227*da0073e9SAndroid Build Coastguard Worker        is used in conjunction with DataParallel.
7228*da0073e9SAndroid Build Coastguard Worker        """
7229*da0073e9SAndroid Build Coastguard Worker
7230*da0073e9SAndroid Build Coastguard Worker        class LinearModule(torch.nn.Module):
7231*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7232*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7233*da0073e9SAndroid Build Coastguard Worker                self.linear = nn.Linear(2, 2, bias=False)
7234*da0073e9SAndroid Build Coastguard Worker
7235*da0073e9SAndroid Build Coastguard Worker            def forward(self, inp):
7236*da0073e9SAndroid Build Coastguard Worker                return self.linear(inp)
7237*da0073e9SAndroid Build Coastguard Worker
7238*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, requires_grad=True)
7239*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7240*da0073e9SAndroid Build Coastguard Worker            a = a.cuda()
7241*da0073e9SAndroid Build Coastguard Worker
7242*da0073e9SAndroid Build Coastguard Worker        model = LinearModule()
7243*da0073e9SAndroid Build Coastguard Worker        if torch.cuda.is_available():
7244*da0073e9SAndroid Build Coastguard Worker            model = model.cuda()
7245*da0073e9SAndroid Build Coastguard Worker
7246*da0073e9SAndroid Build Coastguard Worker        b = deepcopy(model)(a).sum()
7247*da0073e9SAndroid Build Coastguard Worker        b.backward()
7248*da0073e9SAndroid Build Coastguard Worker        b_grad = a.grad
7249*da0073e9SAndroid Build Coastguard Worker
7250*da0073e9SAndroid Build Coastguard Worker        a.grad = None
7251*da0073e9SAndroid Build Coastguard Worker
7252*da0073e9SAndroid Build Coastguard Worker        module = torch.nn.DataParallel(deepcopy(model))
7253*da0073e9SAndroid Build Coastguard Worker        c = checkpoint(module, a, use_reentrant=False).sum()
7254*da0073e9SAndroid Build Coastguard Worker        c.backward()
7255*da0073e9SAndroid Build Coastguard Worker        c_grad = a.grad
7256*da0073e9SAndroid Build Coastguard Worker
7257*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_grad, c_grad)
7258*da0073e9SAndroid Build Coastguard Worker
7259*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
7260*da0073e9SAndroid Build Coastguard Worker        """
7261*da0073e9SAndroid Build Coastguard Worker        Ensures that gradient hooks are only called once per tensor.
7262*da0073e9SAndroid Build Coastguard Worker        """
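        # The concern is that the recomputation done by the checkpointed backward
        # could fire the gradient hook on w a second time; the expectation checked
        # below is that the hook runs exactly once, on the accumulated gradient.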
7263*da0073e9SAndroid Build Coastguard Worker        w = torch.randn(10, 10, requires_grad=True)
7264*da0073e9SAndroid Build Coastguard Worker        count = 0
7265*da0073e9SAndroid Build Coastguard Worker
7266*da0073e9SAndroid Build Coastguard Worker        def hook(grad):
7267*da0073e9SAndroid Build Coastguard Worker            nonlocal count
7268*da0073e9SAndroid Build Coastguard Worker            count += 1
7269*da0073e9SAndroid Build Coastguard Worker
7270*da0073e9SAndroid Build Coastguard Worker        w.register_hook(hook)
7271*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(10, 10, requires_grad=True)
7272*da0073e9SAndroid Build Coastguard Worker        h = w * x  # Using w outside the checkpoint
7273*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(
7274*da0073e9SAndroid Build Coastguard Worker            lambda x: w * x, h, use_reentrant=False
7275*da0073e9SAndroid Build Coastguard Worker        )  # Using w inside the checkpoint
7276*da0073e9SAndroid Build Coastguard Worker
7277*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
7278*da0073e9SAndroid Build Coastguard Worker        # should only call hook once
7279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count, 1)
7280*da0073e9SAndroid Build Coastguard Worker
7281*da0073e9SAndroid Build Coastguard Worker    # https://github.com/pytorch/pytorch/issues/127115
7282*da0073e9SAndroid Build Coastguard Worker    @xfailIfTorchDynamo
7283*da0073e9SAndroid Build Coastguard Worker    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
7284*da0073e9SAndroid Build Coastguard Worker        """
7285*da0073e9SAndroid Build Coastguard Worker        Ensures checkpointing without reentrant autograd works with functions
7286*da0073e9SAndroid Build Coastguard Worker        with arbitrary input/output structures.
7287*da0073e9SAndroid Build Coastguard Worker        """
7288*da0073e9SAndroid Build Coastguard Worker
7289*da0073e9SAndroid Build Coastguard Worker        class MyModel(torch.nn.Module):
7290*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
7291*da0073e9SAndroid Build Coastguard Worker                super().__init__()
7292*da0073e9SAndroid Build Coastguard Worker                self.layer = torch.nn.Linear(5, 5, bias=False)
7293*da0073e9SAndroid Build Coastguard Worker
7294*da0073e9SAndroid Build Coastguard Worker            def forward(self, dict_input):
7295*da0073e9SAndroid Build Coastguard Worker                tensor = dict_input["tensor"]
7296*da0073e9SAndroid Build Coastguard Worker                return {"result": self.layer(tensor)}
7297*da0073e9SAndroid Build Coastguard Worker
7298*da0073e9SAndroid Build Coastguard Worker        model_no_checkpoint = MyModel()
7299*da0073e9SAndroid Build Coastguard Worker        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)
7300*da0073e9SAndroid Build Coastguard Worker
7301*da0073e9SAndroid Build Coastguard Worker        inp = {"tensor": torch.randn(5, 5)}
7302*da0073e9SAndroid Build Coastguard Worker
7303*da0073e9SAndroid Build Coastguard Worker        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()
7304*da0073e9SAndroid Build Coastguard Worker
7305*da0073e9SAndroid Build Coastguard Worker        out_checkpoint = checkpoint(
7306*da0073e9SAndroid Build Coastguard Worker            model_checkpoint_without_reentrant, inp, use_reentrant=False
7307*da0073e9SAndroid Build Coastguard Worker        )["result"].sum()
7308*da0073e9SAndroid Build Coastguard Worker
7309*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out_checkpoint, out_no_checkpoint)
7310*da0073e9SAndroid Build Coastguard Worker
7311*da0073e9SAndroid Build Coastguard Worker        out_no_checkpoint.backward()
7312*da0073e9SAndroid Build Coastguard Worker        out_checkpoint.backward()
7313*da0073e9SAndroid Build Coastguard Worker
7314*da0073e9SAndroid Build Coastguard Worker        for param, checkpoint_param in zip(
7315*da0073e9SAndroid Build Coastguard Worker            model_no_checkpoint.parameters(),
7316*da0073e9SAndroid Build Coastguard Worker            model_checkpoint_without_reentrant.parameters(),
7317*da0073e9SAndroid Build Coastguard Worker        ):
7318*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(param.grad, checkpoint_param.grad)
7319*da0073e9SAndroid Build Coastguard Worker
7320*da0073e9SAndroid Build Coastguard Worker    def test_callback_adds_callback(self):
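        # Callbacks queued via Variable._execution_engine.queue_callback run once
        # the current backward pass finishes; this test checks that a callback may
        # itself enqueue another callback and that both end up executing.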
7321*da0073e9SAndroid Build Coastguard Worker        called = [0]
7322*da0073e9SAndroid Build Coastguard Worker
7323*da0073e9SAndroid Build Coastguard Worker        def callback_final():
7324*da0073e9SAndroid Build Coastguard Worker            called[0] += 1
7325*da0073e9SAndroid Build Coastguard Worker
7326*da0073e9SAndroid Build Coastguard Worker        def callback_adds_callback():
7327*da0073e9SAndroid Build Coastguard Worker            called[0] += 1
7328*da0073e9SAndroid Build Coastguard Worker            Variable._execution_engine.queue_callback(callback_final)
7329*da0073e9SAndroid Build Coastguard Worker
7330*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
7331*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7332*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
7333*da0073e9SAndroid Build Coastguard Worker                return input
7334*da0073e9SAndroid Build Coastguard Worker
7335*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7336*da0073e9SAndroid Build Coastguard Worker            @once_differentiable
7337*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
7338*da0073e9SAndroid Build Coastguard Worker                Variable._execution_engine.queue_callback(callback_adds_callback)
7339*da0073e9SAndroid Build Coastguard Worker                return grad
7340*da0073e9SAndroid Build Coastguard Worker
7341*da0073e9SAndroid Build Coastguard Worker        a = torch.rand((3, 3), requires_grad=True)
7342*da0073e9SAndroid Build Coastguard Worker        b = MyFunc.apply(a)
7343*da0073e9SAndroid Build Coastguard Worker        b.sum().backward()
7344*da0073e9SAndroid Build Coastguard Worker
7345*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(called[0], 2)
7346*da0073e9SAndroid Build Coastguard Worker
7347*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
7348*da0073e9SAndroid Build Coastguard Worker    def test_callback_propagates_errors_from_device_thread(self):
7349*da0073e9SAndroid Build Coastguard Worker        def callback():
7350*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("blah")
7351*da0073e9SAndroid Build Coastguard Worker
7352*da0073e9SAndroid Build Coastguard Worker        def hook_with_callback(*args):
7353*da0073e9SAndroid Build Coastguard Worker            torch.autograd.Variable._execution_engine.queue_callback(callback)
7354*da0073e9SAndroid Build Coastguard Worker
7355*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda"))
7356*da0073e9SAndroid Build Coastguard Worker        t.register_hook(hook_with_callback)
7357*da0073e9SAndroid Build Coastguard Worker        output = t**2
7358*da0073e9SAndroid Build Coastguard Worker        loss = output.sum()
7359*da0073e9SAndroid Build Coastguard Worker
7360*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "blah"):
7361*da0073e9SAndroid Build Coastguard Worker            loss.backward()
7362*da0073e9SAndroid Build Coastguard Worker
7363*da0073e9SAndroid Build Coastguard Worker    def _test_reentrant_with_callbacks(self, install_callbacks_in_depths):
7364*da0073e9SAndroid Build Coastguard Worker        counter = {}
7365*da0073e9SAndroid Build Coastguard Worker        counter["inner"] = 0
7366*da0073e9SAndroid Build Coastguard Worker        counter["outer"] = 0
7367*da0073e9SAndroid Build Coastguard Worker
7368*da0073e9SAndroid Build Coastguard Worker        def inc_inner_counter():
7369*da0073e9SAndroid Build Coastguard Worker            counter["inner"] += 1
7370*da0073e9SAndroid Build Coastguard Worker
7371*da0073e9SAndroid Build Coastguard Worker        def inc_outer_counter():
7372*da0073e9SAndroid Build Coastguard Worker            counter["outer"] += 1
7373*da0073e9SAndroid Build Coastguard Worker
7374*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
7375*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7376*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
7377*da0073e9SAndroid Build Coastguard Worker                return input
7378*da0073e9SAndroid Build Coastguard Worker
7379*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7380*da0073e9SAndroid Build Coastguard Worker            @once_differentiable
7381*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, input):
7382*da0073e9SAndroid Build Coastguard Worker                if 1 in install_callbacks_in_depths:
7383*da0073e9SAndroid Build Coastguard Worker                    # Add a callback to execute.
7384*da0073e9SAndroid Build Coastguard Worker                    Variable._execution_engine.queue_callback(inc_inner_counter)
7385*da0073e9SAndroid Build Coastguard Worker
7386*da0073e9SAndroid Build Coastguard Worker                return input
7387*da0073e9SAndroid Build Coastguard Worker
7388*da0073e9SAndroid Build Coastguard Worker        class MyReentrantFunc(Function):
7389*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7390*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
7391*da0073e9SAndroid Build Coastguard Worker                return input
7392*da0073e9SAndroid Build Coastguard Worker
7393*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7394*da0073e9SAndroid Build Coastguard Worker            @once_differentiable
7395*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, input):
7396*da0073e9SAndroid Build Coastguard Worker                if 0 in install_callbacks_in_depths:
7397*da0073e9SAndroid Build Coastguard Worker                    # Add a callback to execute.
7398*da0073e9SAndroid Build Coastguard Worker                    Variable._execution_engine.queue_callback(inc_outer_counter)
7399*da0073e9SAndroid Build Coastguard Worker                # Reentrant backward call.
7400*da0073e9SAndroid Build Coastguard Worker                tmp_inp = input.detach().requires_grad_()
7401*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
7402*da0073e9SAndroid Build Coastguard Worker                    tmp_out = (MyFunc.apply(tmp_inp)).sum()
7403*da0073e9SAndroid Build Coastguard Worker                tmp_out.backward()
7404*da0073e9SAndroid Build Coastguard Worker                return input
7405*da0073e9SAndroid Build Coastguard Worker
7406*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand((3, 3), requires_grad=True)
7407*da0073e9SAndroid Build Coastguard Worker        t2 = MyReentrantFunc.apply(t1)
7408*da0073e9SAndroid Build Coastguard Worker        t3 = t2.sum()
7409*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward([t3])
7410*da0073e9SAndroid Build Coastguard Worker
7411*da0073e9SAndroid Build Coastguard Worker        return counter
7412*da0073e9SAndroid Build Coastguard Worker
7413*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_with_callbacks_depth_0(self):
7414*da0073e9SAndroid Build Coastguard Worker        # Verify callback is called only once.
7415*da0073e9SAndroid Build Coastguard Worker        ret = self._test_reentrant_with_callbacks([0])
7416*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, ret["outer"])
7417*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, ret["inner"])
7418*da0073e9SAndroid Build Coastguard Worker
7419*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_with_callbacks_depth_1(self):
7420*da0073e9SAndroid Build Coastguard Worker        # Verify callback is called only once.
7421*da0073e9SAndroid Build Coastguard Worker        ret = self._test_reentrant_with_callbacks([1])
7422*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(0, ret["outer"])
7423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, ret["inner"])
7424*da0073e9SAndroid Build Coastguard Worker
7425*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_with_callbacks_both_depths(self):
7426*da0073e9SAndroid Build Coastguard Worker        # Verify callback is called twice.
7427*da0073e9SAndroid Build Coastguard Worker        ret = self._test_reentrant_with_callbacks([0, 1])
7428*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, ret["outer"])
7429*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(1, ret["inner"])
7430*da0073e9SAndroid Build Coastguard Worker
7431*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_with_leaf_variable_hook(self):
7432*da0073e9SAndroid Build Coastguard Worker        handle = None
7433*da0073e9SAndroid Build Coastguard Worker        param = torch.rand(10, requires_grad=True)
7434*da0073e9SAndroid Build Coastguard Worker
7435*da0073e9SAndroid Build Coastguard Worker        def add_gradient_penalty_to_grad(grad):
7436*da0073e9SAndroid Build Coastguard Worker            handle.remove()
7437*da0073e9SAndroid Build Coastguard Worker            old_param_grad = grad
7438*da0073e9SAndroid Build Coastguard Worker            param.grad = None
7439*da0073e9SAndroid Build Coastguard Worker            # Add some sort of gradient penalty by directly updating the gradients
7440*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
7441*da0073e9SAndroid Build Coastguard Worker                g = grad.detach().requires_grad_()
7442*da0073e9SAndroid Build Coastguard Worker                new_param = param.detach().requires_grad_()
7443*da0073e9SAndroid Build Coastguard Worker                out = ((g * 2) + new_param).sum()
7444*da0073e9SAndroid Build Coastguard Worker                out.backward()
7445*da0073e9SAndroid Build Coastguard Worker            res = g.grad + grad
7446*da0073e9SAndroid Build Coastguard Worker            param.grad = old_param_grad
7447*da0073e9SAndroid Build Coastguard Worker            return res
7448*da0073e9SAndroid Build Coastguard Worker
7449*da0073e9SAndroid Build Coastguard Worker        handle = param.register_hook(add_gradient_penalty_to_grad)
7450*da0073e9SAndroid Build Coastguard Worker        # Forward pass
7451*da0073e9SAndroid Build Coastguard Worker        tmp = param * param
7452*da0073e9SAndroid Build Coastguard Worker        loss = tmp.sum()
7453*da0073e9SAndroid Build Coastguard Worker        # Compute the gradients
7454*da0073e9SAndroid Build Coastguard Worker        loss.backward()
7455*da0073e9SAndroid Build Coastguard Worker
7456*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_with_non_leaf_variable_hook(self):
7457*da0073e9SAndroid Build Coastguard Worker        handle = None
7458*da0073e9SAndroid Build Coastguard Worker        param = torch.rand(10, requires_grad=True)
7459*da0073e9SAndroid Build Coastguard Worker
7460*da0073e9SAndroid Build Coastguard Worker        def manual_increase_gradient(grad):
7461*da0073e9SAndroid Build Coastguard Worker            handle.remove()
7462*da0073e9SAndroid Build Coastguard Worker            # Add some sort of gradient penalty by directly updating the gradients
7463*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
7464*da0073e9SAndroid Build Coastguard Worker                g = grad.detach().requires_grad_()
7465*da0073e9SAndroid Build Coastguard Worker                out = ((g * 2) + 5).sum()
7466*da0073e9SAndroid Build Coastguard Worker                out.backward()
7467*da0073e9SAndroid Build Coastguard Worker            res = g.grad + grad
7468*da0073e9SAndroid Build Coastguard Worker            return res
7469*da0073e9SAndroid Build Coastguard Worker
7470*da0073e9SAndroid Build Coastguard Worker        # Forward pass
7471*da0073e9SAndroid Build Coastguard Worker        tmp = param * param
7472*da0073e9SAndroid Build Coastguard Worker        handle = tmp.register_hook(manual_increase_gradient)
7473*da0073e9SAndroid Build Coastguard Worker        loss = tmp.sum()
7474*da0073e9SAndroid Build Coastguard Worker        # Compute the gradients
7475*da0073e9SAndroid Build Coastguard Worker        loss.backward()
7476*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(param.grad, 6 * param)
7477*da0073e9SAndroid Build Coastguard Worker
7478*da0073e9SAndroid Build Coastguard Worker    def test_grad_fn_attr_bindings(self):
7479*da0073e9SAndroid Build Coastguard Worker        # Check that the getter of each type returns what we want
7480*da0073e9SAndroid Build Coastguard Worker        # See `gen_autograd_functions.py` for how the getters are generated
7481*da0073e9SAndroid Build Coastguard Worker        #
7482*da0073e9SAndroid Build Coastguard Worker        # This test is only meant to check if the codegen'd bindings work
7483*da0073e9SAndroid Build Coastguard Worker        # Please help update this test if you update the names of any of the fields we check!
7484*da0073e9SAndroid Build Coastguard Worker        #
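        # As a rough guide: _saved_<name> unpacks the saved value into a plain
        # Python object, while _raw_saved_<name> exposes the underlying
        # torch._C._autograd.SavedTensor wrapper (e.g. so register_hooks can be
        # called on it before unpacking).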
7485*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, requires_grad=True)
7486*da0073e9SAndroid Build Coastguard Worker        b = torch.zeros(1, requires_grad=True)
7487*da0073e9SAndroid Build Coastguard Worker        out1 = torch.stack([a, b], dim=0)
7488*da0073e9SAndroid Build Coastguard Worker        out2 = (a * 2) * b
7489*da0073e9SAndroid Build Coastguard Worker        # TODO: I don't think we have a backward that saves a list of tensors
7490*da0073e9SAndroid Build Coastguard Worker        #       at the moment. It used to be stack, but that changed for no
7491*da0073e9SAndroid Build Coastguard Worker        #       apparent reason; see discussion in #84993
7492*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(out.grad_fn._saved_tensors, (a, b))              # TensorList -> Tuple[Tensor]
7493*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out2.grad_fn._saved_self, a * 2)
7494*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor)
7495*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7496*da0073e9SAndroid Build Coastguard Worker            out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor
7497*da0073e9SAndroid Build Coastguard Worker        )
7498*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1.grad_fn._saved_dim, 0)  # int64_t -> int
7499*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out1.grad_fn._saved_dim, int)
7500*da0073e9SAndroid Build Coastguard Worker
7501*da0073e9SAndroid Build Coastguard Worker        out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
7502*da0073e9SAndroid Build Coastguard Worker
7503*da0073e9SAndroid Build Coastguard Worker        out2.sum().backward()
7504*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7505*da0073e9SAndroid Build Coastguard Worker            out2.grad_fn._saved_self
7506*da0073e9SAndroid Build Coastguard Worker        # TODO: interestingly, this only happens when indexing into a list, e.g.
7507*da0073e9SAndroid Build Coastguard Worker        #       grad_fn._raw_saved_tensors[0], not when using a saved tensor; see discussion in #84993
7508*da0073e9SAndroid Build Coastguard Worker        # with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7509*da0073e9SAndroid Build Coastguard Worker        #     out2.grad_fn._raw_saved_self
7510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out1.grad_fn._saved_dim, 0)
7511*da0073e9SAndroid Build Coastguard Worker
7512*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 2, requires_grad=True)
7513*da0073e9SAndroid Build Coastguard Worker        indices = torch.tensor([0, 1])
7514*da0073e9SAndroid Build Coastguard Worker        out = a[:, indices]
7515*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7516*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_indices, (None, indices)
7517*da0073e9SAndroid Build Coastguard Worker        )  # c10::List<std::optional<Tensor>> -> Tuple[Tensor?]
7518*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor)
7519*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7520*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor
7521*da0073e9SAndroid Build Coastguard Worker        )
7522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7523*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_self_sym_sizes, a.shape
7524*da0073e9SAndroid Build Coastguard Worker        )  # SymIntArrayRef -> Tuple[SymInt]
7525*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int)
7526*da0073e9SAndroid Build Coastguard Worker
7527*da0073e9SAndroid Build Coastguard Worker        out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x)
7528*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
7529*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x)
7530*da0073e9SAndroid Build Coastguard Worker
7531*da0073e9SAndroid Build Coastguard Worker        out = a.mean()
7532*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7533*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_self_sym_sizes, a.shape
7534*da0073e9SAndroid Build Coastguard Worker        )  # IntArrayRef -> Tuple[int]
7535*da0073e9SAndroid Build Coastguard Worker
7536*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 2, requires_grad=True)
7537*da0073e9SAndroid Build Coastguard Worker        out = a * a
7538*da0073e9SAndroid Build Coastguard Worker        out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
7539*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
7540*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "after it has been freed"):
7541*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
7542*da0073e9SAndroid Build Coastguard Worker
7543*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 2, requires_grad=True)
7544*da0073e9SAndroid Build Coastguard Worker        out = torch.nn.functional.interpolate(a, 4, mode="linear")
7545*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7546*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_output_size, (4,)
7547*da0073e9SAndroid Build Coastguard Worker        )  # std::optional<IntArrayRef> -> int[]?
7548*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_output_size[0], int)
7549*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_align_corners, False)  # bool -> bool
7550*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_align_corners, bool)
7551*da0073e9SAndroid Build Coastguard Worker        if hasattr(out.grad_fn, "_saved_scale_factors"):
7552*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(
7553*da0073e9SAndroid Build Coastguard Worker                out.grad_fn._saved_scale_factors
7554*da0073e9SAndroid Build Coastguard Worker            )  # std::optional<ArrayRef<double>> -> float[]?
7555*da0073e9SAndroid Build Coastguard Worker        else:
7556*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(
7557*da0073e9SAndroid Build Coastguard Worker                out.grad_fn._saved_scales
7558*da0073e9SAndroid Build Coastguard Worker            )  # std::optional<ArrayRef<double>> -> float[]?
7559*da0073e9SAndroid Build Coastguard Worker
7560*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 3, 3, requires_grad=True)
7561*da0073e9SAndroid Build Coastguard Worker        out = nn.Conv2d(1, 1, 3)(a)
7562*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7563*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_bias_sym_sizes_opt, (1,)
7564*da0073e9SAndroid Build Coastguard Worker        )  # std::optional<SymIntArrayRef> -> SymInt[]?
7565*da0073e9SAndroid Build Coastguard Worker        out = nn.Conv2d(1, 1, 3, bias=False)(a)
7566*da0073e9SAndroid Build Coastguard Worker        # TODO: This is BAD! We converted a std::nullopt into a (0,)
7567*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,))
7568*da0073e9SAndroid Build Coastguard Worker
7569*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 3, 3, requires_grad=True)
7570*da0073e9SAndroid Build Coastguard Worker        out = torch.addbmm(a.squeeze(0), a, a)
7571*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1)  # int64_t
7572*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3)  # int64_t
7573*da0073e9SAndroid Build Coastguard Worker
7574*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 3, 3, requires_grad=True)
7575*da0073e9SAndroid Build Coastguard Worker        out = torch.nn.functional.unfold(a, 3)
7576*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3)  # SymInt
7577*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3)  # SymInt
7578*da0073e9SAndroid Build Coastguard Worker
7579*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 2, requires_grad=True)
7580*da0073e9SAndroid Build Coastguard Worker        out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear")
7581*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_scales, 0.5)
7582*da0073e9SAndroid Build Coastguard Worker
7583*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 2, requires_grad=True)
7584*da0073e9SAndroid Build Coastguard Worker        out = torch.pdist(a, p=1)
7585*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_p, 1.0)  # double -> float
7586*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_p, float)
7587*da0073e9SAndroid Build Coastguard Worker
7588*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, 1, 2, requires_grad=True)
7589*da0073e9SAndroid Build Coastguard Worker        out = torch.logit(a, 1.0)
7590*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad_fn._saved_eps, 1.0)  # c10::optional<double> -> float?
7591*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_eps, float)
7592*da0073e9SAndroid Build Coastguard Worker        out = torch.logit(a)
7593*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out.grad_fn._saved_eps)
7594*da0073e9SAndroid Build Coastguard Worker
7595*da0073e9SAndroid Build Coastguard Worker        if torch._C.has_lapack:
7596*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1, 1, requires_grad=True)
7597*da0073e9SAndroid Build Coastguard Worker            q, r = torch.linalg.qr(a, mode="reduced")
7598*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(q.grad_fn._saved_mode, "reduced")  # std::string -> str
7599*da0073e9SAndroid Build Coastguard Worker
7600*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0], requires_grad=True)
7601*da0073e9SAndroid Build Coastguard Worker        out = torch.div(a, 2.0, rounding_mode="trunc")
7602*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7603*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_rounding_mode, "trunc"
7604*da0073e9SAndroid Build Coastguard Worker        )  # std::optional<std::string> -> str?
7605*da0073e9SAndroid Build Coastguard Worker        out = torch.div(a, 2.0, rounding_mode=None)
7606*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(
7607*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_rounding_mode
7608*da0073e9SAndroid Build Coastguard Worker        )  # std::optional<std::string> -> str?
7609*da0073e9SAndroid Build Coastguard Worker
7610*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(5, requires_grad=True)
7611*da0073e9SAndroid Build Coastguard Worker        out = torch.threshold(x, threshold=(1 + 0j), value=(1 + 0j))
7612*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7613*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_threshold, complex
7614*da0073e9SAndroid Build Coastguard Worker        )  # Scalar(complex double) -> complex
7615*da0073e9SAndroid Build Coastguard Worker        cfloat = torch.tensor(1 + 0j, dtype=torch.complex64)
7616*da0073e9SAndroid Build Coastguard Worker        out = torch.threshold(x, threshold=cfloat, value=(1 + 0j))
7617*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7618*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_threshold, complex
7619*da0073e9SAndroid Build Coastguard Worker        )  # Scalar(complex float) -> complex
7620*da0073e9SAndroid Build Coastguard Worker        out = torch.threshold(x, threshold=1.0, value=1.0)
7621*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7622*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_threshold, float
7623*da0073e9SAndroid Build Coastguard Worker        )  # Scalar(floating point) -> float
7624*da0073e9SAndroid Build Coastguard Worker        out = torch.threshold(x, threshold=1, value=1)
7625*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7626*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_threshold, int
7627*da0073e9SAndroid Build Coastguard Worker        )  # Scalar(integral) -> int
7628*da0073e9SAndroid Build Coastguard Worker        out = torch.threshold(x, threshold=False, value=False)
7629*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(
7630*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_threshold, bool
7631*da0073e9SAndroid Build Coastguard Worker        )  # Scalar(bool) -> bool
7632*da0073e9SAndroid Build Coastguard Worker
7633*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, 2, requires_grad=True)
7634*da0073e9SAndroid Build Coastguard Worker        out = a.as_strided((3,), (1,), 1)
7635*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7636*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_storage_offset, 1
7637*da0073e9SAndroid Build Coastguard Worker        )  # c10::optional<int64_t> -> int?
7638*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn._saved_storage_offset, int)
7639*da0073e9SAndroid Build Coastguard Worker        out = a.as_strided((3,), (1,))
7640*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out.grad_fn._saved_storage_offset)
7641*da0073e9SAndroid Build Coastguard Worker
7642*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, requires_grad=True)
7643*da0073e9SAndroid Build Coastguard Worker        out = torch.tanh(a)
7644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out, out.grad_fn._saved_result)  # saved variable when output
7645*da0073e9SAndroid Build Coastguard Worker
7646*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(3, 5, requires_grad=True)
7647*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([1, 0, 4])
7648*da0073e9SAndroid Build Coastguard Worker        loss = nn.NLLLoss()
7649*da0073e9SAndroid Build Coastguard Worker        out = loss(a, b)
7650*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out.grad_fn._saved_weight)
7651*da0073e9SAndroid Build Coastguard Worker        loss = nn.NLLLoss(weight=torch.ones((5,)))
7652*da0073e9SAndroid Build Coastguard Worker        out = loss(a, b)
7653*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
7654*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_weight, torch.ones((5,))
7655*da0073e9SAndroid Build Coastguard Worker        )  # c10::optional<Tensor> -> Tensor?
7656*da0073e9SAndroid Build Coastguard Worker
7657*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
7658*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7659*da0073e9SAndroid Build Coastguard Worker            out.grad_fn._saved_weight
7660*da0073e9SAndroid Build Coastguard Worker
7661*da0073e9SAndroid Build Coastguard Worker        num_tensors = 3
7662*da0073e9SAndroid Build Coastguard Worker        input_tensors = [
7663*da0073e9SAndroid Build Coastguard Worker            torch.ones(2, 2, requires_grad=True) for _ in range(num_tensors)
7664*da0073e9SAndroid Build Coastguard Worker        ]
7665*da0073e9SAndroid Build Coastguard Worker        scalars = [
7666*da0073e9SAndroid Build Coastguard Worker            0.0 for _ in range(num_tensors)
7667*da0073e9SAndroid Build Coastguard Worker        ]  # ArrayRef<Scalar> -> Tuple[Scalar, ...]
7668*da0073e9SAndroid Build Coastguard Worker        results = torch._foreach_maximum(input_tensors, scalars)
7669*da0073e9SAndroid Build Coastguard Worker        for t in results:
7670*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.grad_fn._saved_scalars, scalars)
7671*da0073e9SAndroid Build Coastguard Worker
7672*da0073e9SAndroid Build Coastguard Worker    def test_cant_create_saved_tensors(self):
7673*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7674*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7675*da0073e9SAndroid Build Coastguard Worker            "Trying to create a SavedTensor object from Python is forbidden",
7676*da0073e9SAndroid Build Coastguard Worker        ):
7677*da0073e9SAndroid Build Coastguard Worker            torch.autograd.SavedTensor()
7678*da0073e9SAndroid Build Coastguard Worker
7679*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_saved_tensors(self):
7680*da0073e9SAndroid Build Coastguard Worker        def getFn(save=True):
7681*da0073e9SAndroid Build Coastguard Worker            class MyFn(Function):
7682*da0073e9SAndroid Build Coastguard Worker                @staticmethod
7683*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x):
7684*da0073e9SAndroid Build Coastguard Worker                    if save:
7685*da0073e9SAndroid Build Coastguard Worker                        ctx.save_for_backward(x, None)
7686*da0073e9SAndroid Build Coastguard Worker                    return x
7687*da0073e9SAndroid Build Coastguard Worker
7688*da0073e9SAndroid Build Coastguard Worker                @staticmethod
7689*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, g):
7690*da0073e9SAndroid Build Coastguard Worker                    return g
7691*da0073e9SAndroid Build Coastguard Worker
7692*da0073e9SAndroid Build Coastguard Worker            return MyFn
7693*da0073e9SAndroid Build Coastguard Worker
7694*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
7695*da0073e9SAndroid Build Coastguard Worker
7696*da0073e9SAndroid Build Coastguard Worker        y = getFn(True).apply(a)
7697*da0073e9SAndroid Build Coastguard Worker
7698*da0073e9SAndroid Build Coastguard Worker        self.assertEqual((a, None), y.grad_fn.saved_tensors)
7699*da0073e9SAndroid Build Coastguard Worker        saved = y.grad_fn._raw_saved_tensors
7700*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(saved[0], torch._C._autograd.SavedTensor)
7701*da0073e9SAndroid Build Coastguard Worker        # We can't tell the underlying tensor is None without unpacking it
7702*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(saved[1], torch._C._autograd.SavedTensor)
7703*da0073e9SAndroid Build Coastguard Worker
7704*da0073e9SAndroid Build Coastguard Worker        # We catch that error when the user calls register_hooks on it
7705*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "None is forbidden"):
7706*da0073e9SAndroid Build Coastguard Worker            saved[1].register_hooks(lambda x: x, lambda x: x)
7707*da0073e9SAndroid Build Coastguard Worker
7708*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
7709*da0073e9SAndroid Build Coastguard Worker            saved[0].register_hooks(lambda x: x)
7710*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
7711*da0073e9SAndroid Build Coastguard Worker            saved[0].register_hooks(1, 1)
7712*da0073e9SAndroid Build Coastguard Worker        saved[0].register_hooks(lambda x: x, lambda x: x)
7713*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "already been set"):
7714*da0073e9SAndroid Build Coastguard Worker            saved[0].register_hooks(lambda x: x, lambda x: x)
7715*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
7716*da0073e9SAndroid Build Coastguard Worker
7717*da0073e9SAndroid Build Coastguard Worker        # Using a reference to the SavedTensor object after the
7718*da0073e9SAndroid Build Coastguard Worker        # saved variables have been released can lead to undefined behavior
7719*da0073e9SAndroid Build Coastguard Worker        del saved
7720*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7721*da0073e9SAndroid Build Coastguard Worker            y.grad_fn._raw_saved_tensors
7722*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
7723*da0073e9SAndroid Build Coastguard Worker            y.grad_fn.saved_tensors
7724*da0073e9SAndroid Build Coastguard Worker
7725*da0073e9SAndroid Build Coastguard Worker        y = getFn(False).apply(a)
7726*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad_fn.saved_tensors, ())
7727*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad_fn._raw_saved_tensors, ())
7728*da0073e9SAndroid Build Coastguard Worker
7729*da0073e9SAndroid Build Coastguard Worker    def test_autograd_node_isinstance(self):
7730*da0073e9SAndroid Build Coastguard Worker        # Node is a "virtual" base class of codegen'd nodes. This means that
7731*da0073e9SAndroid Build Coastguard Worker        # isinstance and issubclass are overridden, but mro is unchanged
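        # (This virtual-subclass behavior is presumably provided by an abc-style
        # __subclasshook__ on Node, so codegen'd node classes pass isinstance and
        # issubclass checks without Node appearing in their MRO.)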
7732*da0073e9SAndroid Build Coastguard Worker        Node = torch.autograd.graph.Node
7733*da0073e9SAndroid Build Coastguard Worker
7734*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(3, 3, requires_grad=True)
7735*da0073e9SAndroid Build Coastguard Worker        b = a.exp()
7736*da0073e9SAndroid Build Coastguard Worker
7737*da0073e9SAndroid Build Coastguard Worker        # Some nodes have codegen'd registrations in the torch._C._functions module
7738*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(b.grad_fn, Node)
7739*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(type(b.grad_fn), Node))
7740*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(Node not in type(b.grad_fn).mro())
7741*da0073e9SAndroid Build Coastguard Worker
7742*da0073e9SAndroid Build Coastguard Worker        # Other nodes have manual registrations in the torch._C._functions module
7743*da0073e9SAndroid Build Coastguard Worker        self.assertNotIsInstance(torch._C._functions.AccumulateGrad, Node)
7744*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(torch._C._functions.AccumulateGrad, Node))
7745*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(b.grad_fn.next_functions[0][0], Node)
7746*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(torch._C._functions.DelayedError, Node))
7747*da0073e9SAndroid Build Coastguard Worker
7748*da0073e9SAndroid Build Coastguard Worker        # Special cases
7749*da0073e9SAndroid Build Coastguard Worker        self.assertNotIsInstance(None, Node)
7750*da0073e9SAndroid Build Coastguard Worker        self.assertNotIsInstance(1, Node)
7751*da0073e9SAndroid Build Coastguard Worker        self.assertNotIsInstance(Node, Node)
7752*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(Node, Node))
7753*da0073e9SAndroid Build Coastguard Worker
7754*da0073e9SAndroid Build Coastguard Worker        # Custom function case
7755*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(torch.autograd.function.BackwardCFunction, Node))
7756*da0073e9SAndroid Build Coastguard Worker
7757*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
7758*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7759*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
7760*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(ctx, Node)
7761*da0073e9SAndroid Build Coastguard Worker                return x
7762*da0073e9SAndroid Build Coastguard Worker
7763*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7764*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
7765*da0073e9SAndroid Build Coastguard Worker                self.assertIsInstance(ctx, Node)
7766*da0073e9SAndroid Build Coastguard Worker                return x
7767*da0073e9SAndroid Build Coastguard Worker
7768*da0073e9SAndroid Build Coastguard Worker        out = Func.apply(a)
7769*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(out.grad_fn, Node)
7770*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(issubclass(type(out.grad_fn), Node))
7771*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(Node not in type(out.grad_fn).mro())
7772*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
7773*da0073e9SAndroid Build Coastguard Worker
7774*da0073e9SAndroid Build Coastguard Worker    def test_autograd_views_codegen(self):
7775*da0073e9SAndroid Build Coastguard Worker        # This is not necessarily the absolutely correct behavior, but this is the current
7776*da0073e9SAndroid Build Coastguard Worker        # one. This test is here to make sure that any change to this behavior is detected
7777*da0073e9SAndroid Build Coastguard Worker        # and not silent. The TODOs below mark the places with unexpected behavior.
7778*da0073e9SAndroid Build Coastguard Worker        # Note that any change to these tests will be BC-breaking and should be made carefully.
7779*da0073e9SAndroid Build Coastguard Worker
7780*da0073e9SAndroid Build Coastguard Worker        # This test checks the behavior of two codegen functions (view_as and unbind)
7781*da0073e9SAndroid Build Coastguard Worker        # with respect to view tracking and inplace operations on their outputs.
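        # Illustrative sketch of the behavior checked below (not executed here), with grad
        # mode on and a requires_grad input:
        #   inp.view_as(inp).add_(1)   # allowed: plain differentiable view
        #   inp.unbind()[0].add_(1)    # raises: output of a function returning multiple views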
7782*da0073e9SAndroid Build Coastguard Worker
7783*da0073e9SAndroid Build Coastguard Worker        def run_test(grad_mode, requires_grad, is_view, should_raise_tuple):
7784*da0073e9SAndroid Build Coastguard Worker            def maybe_check_raise(fn, should_raise):
7785*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(should_raise is None or isinstance(should_raise, str))
7786*da0073e9SAndroid Build Coastguard Worker                if should_raise is not None:
7787*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, should_raise):
7788*da0073e9SAndroid Build Coastguard Worker                        fn()
7789*da0073e9SAndroid Build Coastguard Worker                else:
7790*da0073e9SAndroid Build Coastguard Worker                    fn()
7791*da0073e9SAndroid Build Coastguard Worker
7792*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand(2, requires_grad=requires_grad).clone()
7793*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(grad_mode):
7794*da0073e9SAndroid Build Coastguard Worker                out = inp.view_as(inp)
7795*da0073e9SAndroid Build Coastguard Worker            # Are they differentiable views?
7796*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out._is_view() == is_view)
7797*da0073e9SAndroid Build Coastguard Worker            # Are inplace operations allowed?
7798*da0073e9SAndroid Build Coastguard Worker            maybe_check_raise(lambda: out.add_(1), should_raise_tuple[0])
7799*da0073e9SAndroid Build Coastguard Worker
7800*da0073e9SAndroid Build Coastguard Worker            inp = torch.rand(2, requires_grad=requires_grad).clone()
7801*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(grad_mode):
7802*da0073e9SAndroid Build Coastguard Worker                out = inp.unbind()
7803*da0073e9SAndroid Build Coastguard Worker            # Are they differentiable views?
7804*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out[0]._is_view() == is_view)
7805*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out[1]._is_view() == is_view)
7806*da0073e9SAndroid Build Coastguard Worker            # Are inplace operations allowed?
7807*da0073e9SAndroid Build Coastguard Worker            maybe_check_raise(lambda: out[0].add_(1), should_raise_tuple[1])
7808*da0073e9SAndroid Build Coastguard Worker            maybe_check_raise(lambda: out[1].add_(1), should_raise_tuple[2])
7809*da0073e9SAndroid Build Coastguard Worker
7810*da0073e9SAndroid Build Coastguard Worker        # Each element of should_raise_tuple is None if the corresponding op should not raise,
7811*da0073e9SAndroid Build Coastguard Worker        # or a string matching the expected error message if it should raise.
7812*da0073e9SAndroid Build Coastguard Worker        # The 3 elements are for view_as, the first output of unbind, and the second output of unbind.
7813*da0073e9SAndroid Build Coastguard Worker        run_test(
7814*da0073e9SAndroid Build Coastguard Worker            grad_mode=True,
7815*da0073e9SAndroid Build Coastguard Worker            requires_grad=False,
7816*da0073e9SAndroid Build Coastguard Worker            is_view=True,
7817*da0073e9SAndroid Build Coastguard Worker            should_raise_tuple=(None, None, None),
7818*da0073e9SAndroid Build Coastguard Worker        )
7819*da0073e9SAndroid Build Coastguard Worker        inp_change_err = (
7820*da0073e9SAndroid Build Coastguard Worker            "Output {} of UnbindBackward0 is a view and is being modified inplace."
7821*da0073e9SAndroid Build Coastguard Worker        )
7822*da0073e9SAndroid Build Coastguard Worker        run_test(
7823*da0073e9SAndroid Build Coastguard Worker            grad_mode=True,
7824*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
7825*da0073e9SAndroid Build Coastguard Worker            is_view=True,
7826*da0073e9SAndroid Build Coastguard Worker            should_raise_tuple=(
7827*da0073e9SAndroid Build Coastguard Worker                None,
7828*da0073e9SAndroid Build Coastguard Worker                inp_change_err.format("0"),
7829*da0073e9SAndroid Build Coastguard Worker                inp_change_err.format("1"),
7830*da0073e9SAndroid Build Coastguard Worker            ),
7831*da0073e9SAndroid Build Coastguard Worker        )
7832*da0073e9SAndroid Build Coastguard Worker        leaf_grad_err = (
7833*da0073e9SAndroid Build Coastguard Worker            "A view was created in no_grad mode and is being modified inplace"
7834*da0073e9SAndroid Build Coastguard Worker        )
7835*da0073e9SAndroid Build Coastguard Worker        run_test(
7836*da0073e9SAndroid Build Coastguard Worker            grad_mode=False,
7837*da0073e9SAndroid Build Coastguard Worker            requires_grad=True,
7838*da0073e9SAndroid Build Coastguard Worker            is_view=True,
7839*da0073e9SAndroid Build Coastguard Worker            should_raise_tuple=(leaf_grad_err, leaf_grad_err, leaf_grad_err),
7840*da0073e9SAndroid Build Coastguard Worker        )
7841*da0073e9SAndroid Build Coastguard Worker        run_test(
7842*da0073e9SAndroid Build Coastguard Worker            grad_mode=False,
7843*da0073e9SAndroid Build Coastguard Worker            requires_grad=False,
7844*da0073e9SAndroid Build Coastguard Worker            is_view=True,
7845*da0073e9SAndroid Build Coastguard Worker            should_raise_tuple=(None, None, None),
7846*da0073e9SAndroid Build Coastguard Worker        )
7847*da0073e9SAndroid Build Coastguard Worker
7848*da0073e9SAndroid Build Coastguard Worker    def test_inplace_not_requires_grad(self):
7849*da0073e9SAndroid Build Coastguard Worker        class MyFn(torch.autograd.Function):
7850*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7851*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp):
7852*da0073e9SAndroid Build Coastguard Worker                return inp.view_as(inp)
7853*da0073e9SAndroid Build Coastguard Worker
7854*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7855*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
7856*da0073e9SAndroid Build Coastguard Worker                return grad
7857*da0073e9SAndroid Build Coastguard Worker
7858*da0073e9SAndroid Build Coastguard Worker        # Original Tensor does not require grad
7859*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, 2)
7860*da0073e9SAndroid Build Coastguard Worker
7861*da0073e9SAndroid Build Coastguard Worker        # Tensor being written does require grad
7862*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1, requires_grad=True)
7863*da0073e9SAndroid Build Coastguard Worker
7864*da0073e9SAndroid Build Coastguard Worker        # Take a view of 'a' through the custom Function; modifying it in-place should raise an error (warns during deprecation)
7865*da0073e9SAndroid Build Coastguard Worker        view_a = MyFn.apply(a)
7866*da0073e9SAndroid Build Coastguard Worker
7867*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7868*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "This view was created inside a custom Function"
7869*da0073e9SAndroid Build Coastguard Worker        ):
7870*da0073e9SAndroid Build Coastguard Worker            view_a += b
7871*da0073e9SAndroid Build Coastguard Worker
7872*da0073e9SAndroid Build Coastguard Worker        # Extra test for copy_ that is a manual implementation and could be easily
7873*da0073e9SAndroid Build Coastguard Worker        # forgotten when the codegen is updated (warns during deprecation)
7874*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, 2)
7875*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1, requires_grad=True)
7876*da0073e9SAndroid Build Coastguard Worker        view_a = MyFn.apply(a)
7877*da0073e9SAndroid Build Coastguard Worker
7878*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7879*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "This view was created inside a custom Function"
7880*da0073e9SAndroid Build Coastguard Worker        ):
7881*da0073e9SAndroid Build Coastguard Worker            view_a.copy_(b)
7882*da0073e9SAndroid Build Coastguard Worker
7883*da0073e9SAndroid Build Coastguard Worker        # Functions that should throw must properly throw
7884*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, 2)
7885*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1, requires_grad=True)
7886*da0073e9SAndroid Build Coastguard Worker        view_a = a.unbind()[0]
7887*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
7888*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
7889*da0073e9SAndroid Build Coastguard Worker            "This view is the output of a function that returns multiple views.",
7890*da0073e9SAndroid Build Coastguard Worker        ):
7891*da0073e9SAndroid Build Coastguard Worker            view_a.copy_(b)
7892*da0073e9SAndroid Build Coastguard Worker
7893*da0073e9SAndroid Build Coastguard Worker        # Sanity check that views that should work still work
7894*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(1, 2)
7895*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(1, requires_grad=True)
7896*da0073e9SAndroid Build Coastguard Worker        a.select(1, 0).copy_(b)
7897*da0073e9SAndroid Build Coastguard Worker
7898*da0073e9SAndroid Build Coastguard Worker    def _do_test_autograd_simple_views_python(self, dtype):
7899*da0073e9SAndroid Build Coastguard Worker        # This is not necessarily the absolutely correct behavior, but this is the current
7900*da0073e9SAndroid Build Coastguard Worker        # one. This test is here to make sure that any change to this behavior is detected
7901*da0073e9SAndroid Build Coastguard Worker        # and not silent. The TODOs below mark the places with unexpected behavior.
7902*da0073e9SAndroid Build Coastguard Worker        # Note that any change to these tests will be BC-breaking and should be made carefully.
7903*da0073e9SAndroid Build Coastguard Worker
7904*da0073e9SAndroid Build Coastguard Worker        # This checks the autograd.Function behavior when we return one or multiple outputs
7905*da0073e9SAndroid Build Coastguard Worker        # where one of these is an input, a view of an input, or a view of a temporary tensor.
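        # Rough shape of the three Functions exercised below (illustrative summary):
        #   IdOneOutput: returns `a` narrowed (make_view=True) or cloned (make_view=False)
        #   IdTwoOutput: returns that same tensor together with `a + b`
        #   ViewOfTemp:  returns a select() view of a temporary tensor created in forward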
7906*da0073e9SAndroid Build Coastguard Worker
7907*da0073e9SAndroid Build Coastguard Worker        # This indicator is used to track how many times the backward function was called
7908*da0073e9SAndroid Build Coastguard Worker        bw_called = [0]
7909*da0073e9SAndroid Build Coastguard Worker        # This indicator is used to check if the argument `ga` contains non-zero values
7910*da0073e9SAndroid Build Coastguard Worker        ga_nz = [False]
7911*da0073e9SAndroid Build Coastguard Worker
7912*da0073e9SAndroid Build Coastguard Worker        class IdOneOutput(Function):
7913*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7914*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b, make_view):
7915*da0073e9SAndroid Build Coastguard Worker                if make_view:
7916*da0073e9SAndroid Build Coastguard Worker                    a = a.narrow(0, 0, 2)
7917*da0073e9SAndroid Build Coastguard Worker                else:
7918*da0073e9SAndroid Build Coastguard Worker                    a = a.clone()
7919*da0073e9SAndroid Build Coastguard Worker                return a
7920*da0073e9SAndroid Build Coastguard Worker
7921*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7922*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, ga):
7923*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
7924*da0073e9SAndroid Build Coastguard Worker                return ga, None, None
7925*da0073e9SAndroid Build Coastguard Worker
7926*da0073e9SAndroid Build Coastguard Worker        class IdTwoOutput(Function):
7927*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7928*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b, make_view):
7929*da0073e9SAndroid Build Coastguard Worker                if make_view:
7930*da0073e9SAndroid Build Coastguard Worker                    a = a.narrow(0, 0, 2)
7931*da0073e9SAndroid Build Coastguard Worker                else:
7932*da0073e9SAndroid Build Coastguard Worker                    a = a.clone()
7933*da0073e9SAndroid Build Coastguard Worker                return a, a + b
7934*da0073e9SAndroid Build Coastguard Worker
7935*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7936*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, ga, gab):
7937*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
7938*da0073e9SAndroid Build Coastguard Worker                if ga.eq(0).all():
7939*da0073e9SAndroid Build Coastguard Worker                    ga_nz[0] = False
7940*da0073e9SAndroid Build Coastguard Worker                else:
7941*da0073e9SAndroid Build Coastguard Worker                    ga_nz[0] = True
7942*da0073e9SAndroid Build Coastguard Worker                return ga + gab, gab, None
7943*da0073e9SAndroid Build Coastguard Worker
7944*da0073e9SAndroid Build Coastguard Worker        class ViewOfTemp(Function):
7945*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7946*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, make_view):
7947*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(a)
7948*da0073e9SAndroid Build Coastguard Worker                if make_view:
7949*da0073e9SAndroid Build Coastguard Worker                    a = a.narrow(0, 0, 2)
7950*da0073e9SAndroid Build Coastguard Worker                else:
7951*da0073e9SAndroid Build Coastguard Worker                    a = a.clone()
7952*da0073e9SAndroid Build Coastguard Worker                b = a.clone()
7953*da0073e9SAndroid Build Coastguard Worker                return b.select(0, 0)
7954*da0073e9SAndroid Build Coastguard Worker
7955*da0073e9SAndroid Build Coastguard Worker            @staticmethod
7956*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
7957*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
7958*da0073e9SAndroid Build Coastguard Worker                (a,) = ctx.saved_tensors
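                # scatter the incoming grad for the selected slice back into a zero
                # tensor shaped like the saved input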
7959*da0073e9SAndroid Build Coastguard Worker                res = torch.zeros_like(a)
7960*da0073e9SAndroid Build Coastguard Worker                res.select(0, 0).copy_(grad)
7961*da0073e9SAndroid Build Coastguard Worker                return res, None
7962*da0073e9SAndroid Build Coastguard Worker
7963*da0073e9SAndroid Build Coastguard Worker        fn_id_to_inplace_on_view_err_msg = {
7964*da0073e9SAndroid Build Coastguard Worker            "one_output": (
7965*da0073e9SAndroid Build Coastguard Worker                "Output 0 of IdOneOutputBackward is a view and is being "
7966*da0073e9SAndroid Build Coastguard Worker                "modified inplace. This view was created inside a custom Function"
7967*da0073e9SAndroid Build Coastguard Worker            ),
7968*da0073e9SAndroid Build Coastguard Worker            "two_output": (
7969*da0073e9SAndroid Build Coastguard Worker                "Output 0 of IdTwoOutputBackward is a view and is being modified inplace."
7970*da0073e9SAndroid Build Coastguard Worker                " This view is the output of a function that returns multiple views."
7971*da0073e9SAndroid Build Coastguard Worker            ),
7972*da0073e9SAndroid Build Coastguard Worker            "view_of_temp": (
7973*da0073e9SAndroid Build Coastguard Worker                "Output 0 of ViewOfTempBackward is a view and is being "
7974*da0073e9SAndroid Build Coastguard Worker                "modified inplace. This view was created inside a custom Function"
7975*da0073e9SAndroid Build Coastguard Worker            ),
7976*da0073e9SAndroid Build Coastguard Worker        }
7977*da0073e9SAndroid Build Coastguard Worker
7978*da0073e9SAndroid Build Coastguard Worker        for fn_id in ["one_output", "two_output", "view_of_temp"]:
7979*da0073e9SAndroid Build Coastguard Worker            for inplace in [True, False]:
7980*da0073e9SAndroid Build Coastguard Worker                for make_view in [True, False]:
7981*da0073e9SAndroid Build Coastguard Worker                    # Used for special-casing the tests below
7982*da0073e9SAndroid Build Coastguard Worker                    output_is_a_view = make_view or fn_id == "view_of_temp"
7983*da0073e9SAndroid Build Coastguard Worker
7984*da0073e9SAndroid Build Coastguard Worker                    def fn(a, b):
7985*da0073e9SAndroid Build Coastguard Worker                        # never modify a, b in-place for gradcheck
7986*da0073e9SAndroid Build Coastguard Worker                        a = a.clone()
7987*da0073e9SAndroid Build Coastguard Worker                        b = b.clone()
7988*da0073e9SAndroid Build Coastguard Worker                        if fn_id == "two_output":
7989*da0073e9SAndroid Build Coastguard Worker                            tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view)
7990*da0073e9SAndroid Build Coastguard Worker                            if inplace:
7991*da0073e9SAndroid Build Coastguard Worker                                tmp1 += 3
7992*da0073e9SAndroid Build Coastguard Worker                                tmp2 += 3
7993*da0073e9SAndroid Build Coastguard Worker                            else:
7994*da0073e9SAndroid Build Coastguard Worker                                tmp1 = tmp1 + 3
7995*da0073e9SAndroid Build Coastguard Worker                                tmp2 = tmp2 + 3
7996*da0073e9SAndroid Build Coastguard Worker                            tmp = tmp1 * tmp2
7997*da0073e9SAndroid Build Coastguard Worker                        else:
7998*da0073e9SAndroid Build Coastguard Worker                            if fn_id == "one_output":
7999*da0073e9SAndroid Build Coastguard Worker                                tmp = IdOneOutput.apply(a, b, make_view)
8000*da0073e9SAndroid Build Coastguard Worker                            else:
8001*da0073e9SAndroid Build Coastguard Worker                                tmp = ViewOfTemp.apply(a + b, make_view)
8002*da0073e9SAndroid Build Coastguard Worker                            if inplace:
8003*da0073e9SAndroid Build Coastguard Worker                                tmp += 3
8004*da0073e9SAndroid Build Coastguard Worker                            else:
8005*da0073e9SAndroid Build Coastguard Worker                                tmp = tmp + 3
8006*da0073e9SAndroid Build Coastguard Worker
8007*da0073e9SAndroid Build Coastguard Worker                        return tmp.sum()
8008*da0073e9SAndroid Build Coastguard Worker
8009*da0073e9SAndroid Build Coastguard Worker                    a = torch.ones(2, dtype=dtype, requires_grad=True)
8010*da0073e9SAndroid Build Coastguard Worker                    b = torch.ones(2, dtype=dtype, requires_grad=True)
8011*da0073e9SAndroid Build Coastguard Worker
8012*da0073e9SAndroid Build Coastguard Worker                    err_msg = fn_id_to_inplace_on_view_err_msg[fn_id]
8013*da0073e9SAndroid Build Coastguard Worker
8014*da0073e9SAndroid Build Coastguard Worker                    if not inplace or not output_is_a_view:
8015*da0073e9SAndroid Build Coastguard Worker                        gradcheck(fn, (a, b), check_batched_grad=False)
8016*da0073e9SAndroid Build Coastguard Worker
8017*da0073e9SAndroid Build Coastguard Worker                    # Was the custom backward called properly
8018*da0073e9SAndroid Build Coastguard Worker                    bw_called[0] = 0
8019*da0073e9SAndroid Build Coastguard Worker                    ga_nz[0] = True  # For the case where the backward is called
8020*da0073e9SAndroid Build Coastguard Worker
8021*da0073e9SAndroid Build Coastguard Worker                    if inplace and output_is_a_view:
8022*da0073e9SAndroid Build Coastguard Worker                        with self.assertRaisesRegex(RuntimeError, err_msg):
8023*da0073e9SAndroid Build Coastguard Worker                            fn(a, b)
8024*da0073e9SAndroid Build Coastguard Worker                    else:
8025*da0073e9SAndroid Build Coastguard Worker                        fn(a, b).abs().backward()
8026*da0073e9SAndroid Build Coastguard Worker
8027*da0073e9SAndroid Build Coastguard Worker                    expected_called = 1
8028*da0073e9SAndroid Build Coastguard Worker                    expected_ga_nz = True
8029*da0073e9SAndroid Build Coastguard Worker
8030*da0073e9SAndroid Build Coastguard Worker                    if output_is_a_view and inplace:
8031*da0073e9SAndroid Build Coastguard Worker                        expected_called = 0
8032*da0073e9SAndroid Build Coastguard Worker
8033*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(bw_called[0] == expected_called)
8034*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(ga_nz[0] == expected_ga_nz)
8035*da0073e9SAndroid Build Coastguard Worker
8036*da0073e9SAndroid Build Coastguard Worker    def test_autograd_simple_views_python(self):
8037*da0073e9SAndroid Build Coastguard Worker        self._do_test_autograd_simple_views_python(torch.double)
8038*da0073e9SAndroid Build Coastguard Worker        self._do_test_autograd_simple_views_python(torch.cdouble)
8039*da0073e9SAndroid Build Coastguard Worker
8040*da0073e9SAndroid Build Coastguard Worker    def test_autograd_inplace_views_creation_meta(self):
8041*da0073e9SAndroid Build Coastguard Worker        # Tests that creation_meta is properly handled for inplace views
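        # creation_meta records how a view was created (e.g. in no_grad mode, as one of
        # multiple outputs, or inside a custom Function) so later in-place ops can raise the
        # matching error; the in-place view ops exercised below must propagate it correctly.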
8042*da0073e9SAndroid Build Coastguard Worker
8043*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8044*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8045*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
8046*da0073e9SAndroid Build Coastguard Worker                return x.view_as(x)
8047*da0073e9SAndroid Build Coastguard Worker
8048*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8049*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
8050*da0073e9SAndroid Build Coastguard Worker                return x
8051*da0073e9SAndroid Build Coastguard Worker
8052*da0073e9SAndroid Build Coastguard Worker        view_custom = Func.apply
8053*da0073e9SAndroid Build Coastguard Worker
8054*da0073e9SAndroid Build Coastguard Worker        def run_test(
8055*da0073e9SAndroid Build Coastguard Worker            fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2
8056*da0073e9SAndroid Build Coastguard Worker        ):
8057*da0073e9SAndroid Build Coastguard Worker            # This test checks the behavior of inplace-view functions when
8058*da0073e9SAndroid Build Coastguard Worker            # the views are created in grad mode or not
8059*da0073e9SAndroid Build Coastguard Worker            base = torch.rand(2, 3, requires_grad=requires_grad).clone()
8060*da0073e9SAndroid Build Coastguard Worker            # 1. Create a view with `grad_mode=grad_mode_view`
8061*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(grad_mode_view):
8062*da0073e9SAndroid Build Coastguard Worker                if fn_type == "multi_view":
8063*da0073e9SAndroid Build Coastguard Worker                    inp = base.unbind()[0]
8064*da0073e9SAndroid Build Coastguard Worker                elif fn_type == "custom":
8065*da0073e9SAndroid Build Coastguard Worker                    inp = view_custom(base)
8066*da0073e9SAndroid Build Coastguard Worker                else:
8067*da0073e9SAndroid Build Coastguard Worker                    inp = base.view_as(base)
8068*da0073e9SAndroid Build Coastguard Worker
8069*da0073e9SAndroid Build Coastguard Worker            # 2. Perform inplace view with `grad_mode=grad_mode_iview`
8070*da0073e9SAndroid Build Coastguard Worker            with torch.set_grad_enabled(grad_mode_iview):
8071*da0073e9SAndroid Build Coastguard Worker                if error1 is not None:
8072*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, error1):
8073*da0073e9SAndroid Build Coastguard Worker                        fn(inp)
8074*da0073e9SAndroid Build Coastguard Worker                    return
8075*da0073e9SAndroid Build Coastguard Worker                else:
8076*da0073e9SAndroid Build Coastguard Worker                    # If error is None, check that it runs without error
8077*da0073e9SAndroid Build Coastguard Worker                    fn(inp)
8078*da0073e9SAndroid Build Coastguard Worker            # 3. Do inplace on the (new) view
8079*da0073e9SAndroid Build Coastguard Worker            if error2 is not None:
8080*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, error2):
8081*da0073e9SAndroid Build Coastguard Worker                    inp.add_(1)
8082*da0073e9SAndroid Build Coastguard Worker            else:
8083*da0073e9SAndroid Build Coastguard Worker                # If error is None, check that it runs without error
8084*da0073e9SAndroid Build Coastguard Worker                inp.add_(1)
8085*da0073e9SAndroid Build Coastguard Worker
8086*da0073e9SAndroid Build Coastguard Worker        no_grad_err = "A view was created in no_grad mode"
8087*da0073e9SAndroid Build Coastguard Worker        multi_view_err = "function that returns multiple views"
8088*da0073e9SAndroid Build Coastguard Worker        custom_err = "view was created inside a custom Function"
8089*da0073e9SAndroid Build Coastguard Worker
8090*da0073e9SAndroid Build Coastguard Worker        def run_tests(fn):
8091*da0073e9SAndroid Build Coastguard Worker            for fn_type in ("normal", "multi_view", "custom"):
8092*da0073e9SAndroid Build Coastguard Worker                for grad_mode_view in (True, False):
8093*da0073e9SAndroid Build Coastguard Worker                    for grad_mode_iview in (True, False):
8094*da0073e9SAndroid Build Coastguard Worker                        for requires_grad in (True, False):
8095*da0073e9SAndroid Build Coastguard Worker                            error1 = None  # expected error when we do inplace_view on original view
8096*da0073e9SAndroid Build Coastguard Worker                            error2 = None  # expected error when we do inplace on the resulting view
8097*da0073e9SAndroid Build Coastguard Worker
8098*da0073e9SAndroid Build Coastguard Worker                            if requires_grad:
8099*da0073e9SAndroid Build Coastguard Worker                                if not grad_mode_view and grad_mode_iview:
8100*da0073e9SAndroid Build Coastguard Worker                                    error1 = no_grad_err
8101*da0073e9SAndroid Build Coastguard Worker                                if not grad_mode_view and not grad_mode_iview:
8102*da0073e9SAndroid Build Coastguard Worker                                    error2 = no_grad_err
8103*da0073e9SAndroid Build Coastguard Worker
8104*da0073e9SAndroid Build Coastguard Worker                                if fn_type == "multi_view":
8105*da0073e9SAndroid Build Coastguard Worker                                    if grad_mode_view and grad_mode_iview:
8106*da0073e9SAndroid Build Coastguard Worker                                        error1 = multi_view_err
8107*da0073e9SAndroid Build Coastguard Worker                                    if grad_mode_view and not grad_mode_iview:
8108*da0073e9SAndroid Build Coastguard Worker                                        error2 = multi_view_err
8109*da0073e9SAndroid Build Coastguard Worker
8110*da0073e9SAndroid Build Coastguard Worker                                if fn_type == "custom":
8111*da0073e9SAndroid Build Coastguard Worker                                    if grad_mode_view and grad_mode_iview:
8112*da0073e9SAndroid Build Coastguard Worker                                        error1 = custom_err
8113*da0073e9SAndroid Build Coastguard Worker                                    if grad_mode_view and not grad_mode_iview:
8114*da0073e9SAndroid Build Coastguard Worker                                        error2 = custom_err
8115*da0073e9SAndroid Build Coastguard Worker
8116*da0073e9SAndroid Build Coastguard Worker                            run_test(
8117*da0073e9SAndroid Build Coastguard Worker                                fn,
8118*da0073e9SAndroid Build Coastguard Worker                                fn_type,
8119*da0073e9SAndroid Build Coastguard Worker                                grad_mode_view,
8120*da0073e9SAndroid Build Coastguard Worker                                grad_mode_iview,
8121*da0073e9SAndroid Build Coastguard Worker                                requires_grad,
8122*da0073e9SAndroid Build Coastguard Worker                                error1,
8123*da0073e9SAndroid Build Coastguard Worker                                error2,
8124*da0073e9SAndroid Build Coastguard Worker                            )
8125*da0073e9SAndroid Build Coastguard Worker
8126*da0073e9SAndroid Build Coastguard Worker        # This list was created by logging gen_inplace_or_view_type.py
8127*da0073e9SAndroid Build Coastguard Worker        #   detach_ is excluded for this test because it cannot be applied to
8128*da0073e9SAndroid Build Coastguard Worker        #   views and thus does not return a view
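        # Pattern exercised by each call below (illustrative sketch):
        #   v = base.view_as(base)   # or base.unbind()[0], or the custom Function view
        #   v.t_()                   # the in-place view op under test; may raise (error1)
        #   v.add_(1)                # follow-up in-place op on the result; may raise (error2)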
8129*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
8130*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.transpose_(0, 0))
8131*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.t_())
8132*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.squeeze_(0))
8133*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.unsqueeze_(0))
8134*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.swapdims_(0, 0))
8135*da0073e9SAndroid Build Coastguard Worker        run_tests(lambda v: v.swapaxes_(0, 0))
8136*da0073e9SAndroid Build Coastguard Worker
8137*da0073e9SAndroid Build Coastguard Worker    def test_autograd_print_tensor(self):
8138*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, requires_grad=True)
8139*da0073e9SAndroid Build Coastguard Worker        a_clone = a.clone()
8140*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(a), "tensor([1.], requires_grad=True)")
8141*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(a_clone), "tensor([1.], grad_fn=<CloneBackward0>)")
8142*da0073e9SAndroid Build Coastguard Worker
8143*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
8144*da0073e9SAndroid Build Coastguard Worker            b = a[:]
8145*da0073e9SAndroid Build Coastguard Worker            b *= 2
8146*da0073e9SAndroid Build Coastguard Worker
8147*da0073e9SAndroid Build Coastguard Worker        # Special handling for printing a view created in no-grad mode and modified
8148*da0073e9SAndroid Build Coastguard Worker        # in-place in no-grad mode.
8149*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(b), "tensor([2.], grad_fn=<Invalid>)")
8150*da0073e9SAndroid Build Coastguard Worker
8151*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8152*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8153*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
8154*da0073e9SAndroid Build Coastguard Worker                return x
8155*da0073e9SAndroid Build Coastguard Worker
8156*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8157*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, x):
8158*da0073e9SAndroid Build Coastguard Worker                return x
8159*da0073e9SAndroid Build Coastguard Worker
8160*da0073e9SAndroid Build Coastguard Worker        c = Func.apply(a)
8161*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(repr(c), "tensor([2.], grad_fn=<FuncBackward>)")
8162*da0073e9SAndroid Build Coastguard Worker
8163*da0073e9SAndroid Build Coastguard Worker    def test_autograd_inplace_view_of_view(self):
8164*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(2)
8165*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
8166*da0073e9SAndroid Build Coastguard Worker            y = x.view(2)
8167*da0073e9SAndroid Build Coastguard Worker        y.requires_grad_(True)
8168*da0073e9SAndroid Build Coastguard Worker        z = y.view(2)
8169*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8170*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "a view of a view .* is being .* inside the no_grad block"
8171*da0073e9SAndroid Build Coastguard Worker        ):
8172*da0073e9SAndroid Build Coastguard Worker            z /= 2
8173*da0073e9SAndroid Build Coastguard Worker
8174*da0073e9SAndroid Build Coastguard Worker        x = torch.zeros(2)
8175*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
8176*da0073e9SAndroid Build Coastguard Worker            y = x.view(2)
8177*da0073e9SAndroid Build Coastguard Worker        y.requires_grad_(True)
8178*da0073e9SAndroid Build Coastguard Worker        z = y.view(2)
8179*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8180*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "a view of a view .* is being .* inside the inference_mode"
8181*da0073e9SAndroid Build Coastguard Worker        ):
8182*da0073e9SAndroid Build Coastguard Worker            z /= 2
8183*da0073e9SAndroid Build Coastguard Worker
8184*da0073e9SAndroid Build Coastguard Worker    # TODO This is not the correct behavior -
8185*da0073e9SAndroid Build Coastguard Worker    # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
8186*da0073e9SAndroid Build Coastguard Worker    def test_autograd_inplace_views_cross_dtype(self):
8187*da0073e9SAndroid Build Coastguard Worker        # This test is here to make sure that any change to this behavior is detected
8188*da0073e9SAndroid Build Coastguard Worker        # and not silent. The TODOs below mark the places with unexpected behavior.
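        # Both branches below take a float view of a complex tensor via view_as_real,
        # transpose it (out-of-place vs. in-place), add 1, and backprop; the resulting grads
        # currently differ by a transpose, which is the bug tracked in the issue linked above.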
8189*da0073e9SAndroid Build Coastguard Worker        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
8190*da0073e9SAndroid Build Coastguard Worker        a = a_orig.clone()
8191*da0073e9SAndroid Build Coastguard Worker        b = torch.view_as_real(a)
8192*da0073e9SAndroid Build Coastguard Worker        b = b.transpose(0, 1)
8193*da0073e9SAndroid Build Coastguard Worker        b += 1
8194*da0073e9SAndroid Build Coastguard Worker        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
8195*da0073e9SAndroid Build Coastguard Worker        non_inplace_grad = a_orig.grad
8196*da0073e9SAndroid Build Coastguard Worker
8197*da0073e9SAndroid Build Coastguard Worker        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
8198*da0073e9SAndroid Build Coastguard Worker        a = a_orig.clone()
8199*da0073e9SAndroid Build Coastguard Worker        b = torch.view_as_real(a)
8200*da0073e9SAndroid Build Coastguard Worker        b.transpose_(0, 1)
8201*da0073e9SAndroid Build Coastguard Worker        b += 1
8202*da0073e9SAndroid Build Coastguard Worker        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
8203*da0073e9SAndroid Build Coastguard Worker        inplace_grad = a_orig.grad
8204*da0073e9SAndroid Build Coastguard Worker
8205*da0073e9SAndroid Build Coastguard Worker        # TODO: this is a bug!
8206*da0073e9SAndroid Build Coastguard Worker        # once this is fixed, it should have the transpose removed:
8207*da0073e9SAndroid Build Coastguard Worker        # self.assertEqual(non_inplace_grad, inplace_grad)
8208*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(non_inplace_grad.T, inplace_grad)
8209*da0073e9SAndroid Build Coastguard Worker
8210*da0073e9SAndroid Build Coastguard Worker    def test_autograd_multiple_views_python(self):
8211*da0073e9SAndroid Build Coastguard Worker        # This is not necessarily the absolutely correct behavior, but this is the current
8212*da0073e9SAndroid Build Coastguard Worker        # one. This test is here to make sure that any change to this behavior is detected
8213*da0073e9SAndroid Build Coastguard Worker        # and not silent. The TODOs below mark the places with unexpected behavior.
8214*da0073e9SAndroid Build Coastguard Worker        # Note that any change to these tests will be BC-breaking and should be made carefully.
8215*da0073e9SAndroid Build Coastguard Worker
8216*da0073e9SAndroid Build Coastguard Worker        # This checks that multiple views in the forward are properly traced and how they
8217*da0073e9SAndroid Build Coastguard Worker        # behave with respect to inplace operations.
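        # ComplexView below creates two views of its input in forward (a narrow that is
        # discarded, then a select) and returns only the select; in-place ops on that
        # returned view are expected to raise.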
8218*da0073e9SAndroid Build Coastguard Worker
8219*da0073e9SAndroid Build Coastguard Worker        # This indicator is used to track how many times the backward function was called
8220*da0073e9SAndroid Build Coastguard Worker        bw_called = [0]
8221*da0073e9SAndroid Build Coastguard Worker
8222*da0073e9SAndroid Build Coastguard Worker        class ComplexView(Function):
8223*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8224*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, idx):
8225*da0073e9SAndroid Build Coastguard Worker                res = a.narrow(0, idx, 1)
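                # the narrow above is intentionally discarded; a second view is created below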
8226*da0073e9SAndroid Build Coastguard Worker                res = a.select(0, idx)
8227*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(a)
8228*da0073e9SAndroid Build Coastguard Worker                ctx.idx = idx
8229*da0073e9SAndroid Build Coastguard Worker                return res
8230*da0073e9SAndroid Build Coastguard Worker
8231*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8232*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
8233*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
8234*da0073e9SAndroid Build Coastguard Worker                (a,) = ctx.saved_tensors
8235*da0073e9SAndroid Build Coastguard Worker                res = torch.zeros_like(a)
8236*da0073e9SAndroid Build Coastguard Worker                res.select(0, ctx.idx).copy_(grad)
8237*da0073e9SAndroid Build Coastguard Worker                return res, None
8238*da0073e9SAndroid Build Coastguard Worker
8239*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, requires_grad=True)
8240*da0073e9SAndroid Build Coastguard Worker        idx = 1
8241*da0073e9SAndroid Build Coastguard Worker
8242*da0073e9SAndroid Build Coastguard Worker        bw_called[0] = 0
8243*da0073e9SAndroid Build Coastguard Worker        out = ComplexView.apply(a.clone(), idx)
8244*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
8245*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bw_called[0] == 1)
8246*da0073e9SAndroid Build Coastguard Worker
8247*da0073e9SAndroid Build Coastguard Worker        out = ComplexView.apply(a.clone(), idx)
8248*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8249*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
8250*da0073e9SAndroid Build Coastguard Worker            "Output 0 of ComplexViewBackward is a view and is being modified inplace",
8251*da0073e9SAndroid Build Coastguard Worker        ):
8252*da0073e9SAndroid Build Coastguard Worker            out += 1
8253*da0073e9SAndroid Build Coastguard Worker
8254*da0073e9SAndroid Build Coastguard Worker    def test_autograd_python_custom_function_inplace(self):
8255*da0073e9SAndroid Build Coastguard Worker        # This is not necessarily the absolutely correct behavior, but this is the current
8256*da0073e9SAndroid Build Coastguard Worker        # one. This test is here to make sure that any change to this behavior is detected
8257*da0073e9SAndroid Build Coastguard Worker        # and not silent. The TODOs below mark the places with unexpected behavior.
8258*da0073e9SAndroid Build Coastguard Worker        # Note that any change to these tests will be BC-breaking and should be made carefully.
8259*da0073e9SAndroid Build Coastguard Worker
8260*da0073e9SAndroid Build Coastguard Worker        # This test checks custom autograd.Functions that perform inplace operations
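        # Typical pattern under test (illustrative sketch): a Function that mutates an input
        # must mark it dirty so autograd rebases the tensor's history onto the Function:
        #   def forward(ctx, a, b):
        #       a.add_(b)
        #       ctx.mark_dirty(a)
        #       return a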
8261*da0073e9SAndroid Build Coastguard Worker
8262*da0073e9SAndroid Build Coastguard Worker        bw_called = [0]
8263*da0073e9SAndroid Build Coastguard Worker
8264*da0073e9SAndroid Build Coastguard Worker        # I) Single output
8265*da0073e9SAndroid Build Coastguard Worker        class MyAdder(Function):
8266*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8267*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
8268*da0073e9SAndroid Build Coastguard Worker                a.add_(b)
8269*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(a)
8270*da0073e9SAndroid Build Coastguard Worker                return a
8271*da0073e9SAndroid Build Coastguard Worker
8272*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8273*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
8274*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
8275*da0073e9SAndroid Build Coastguard Worker                return grad, grad
8276*da0073e9SAndroid Build Coastguard Worker
8277*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, requires_grad=True)
8278*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(2, requires_grad=True)
8279*da0073e9SAndroid Build Coastguard Worker
8280*da0073e9SAndroid Build Coastguard Worker        # No extra inplace
8281*da0073e9SAndroid Build Coastguard Worker        c = MyAdder.apply(a.clone(), b)
8282*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
8283*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bw_called[0] == 1)
8284*da0073e9SAndroid Build Coastguard Worker
8285*da0073e9SAndroid Build Coastguard Worker        # With extra inplace on the output
8286*da0073e9SAndroid Build Coastguard Worker        bw_called[0] = 0
8287*da0073e9SAndroid Build Coastguard Worker        c = MyAdder.apply(a.clone(), b)
8288*da0073e9SAndroid Build Coastguard Worker        c += 2
8289*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
8290*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bw_called[0] == 1)
8291*da0073e9SAndroid Build Coastguard Worker
8292*da0073e9SAndroid Build Coastguard Worker        # The input is a view
8293*da0073e9SAndroid Build Coastguard Worker        bw_called[0] = 0
8294*da0073e9SAndroid Build Coastguard Worker        c = MyAdder.apply(a.clone().view_as(a), b)
8295*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
8296*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bw_called[0] == 1)
8297*da0073e9SAndroid Build Coastguard Worker
8298*da0073e9SAndroid Build Coastguard Worker        # Non-input tensors should not be passed to mark_dirty
8299*da0073e9SAndroid Build Coastguard Worker        class MyAdderBad(Function):
8300*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8301*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
8302*da0073e9SAndroid Build Coastguard Worker                c = 3 * a
8303*da0073e9SAndroid Build Coastguard Worker                c.add_(b)
8304*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(c)
8305*da0073e9SAndroid Build Coastguard Worker                return c
8306*da0073e9SAndroid Build Coastguard Worker
8307*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8308*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
8309*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
8310*da0073e9SAndroid Build Coastguard Worker                grad = 3 * grad
8311*da0073e9SAndroid Build Coastguard Worker                return grad, grad
8312*da0073e9SAndroid Build Coastguard Worker
8313*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(2, requires_grad=True)
8314*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(2, requires_grad=True)
8315*da0073e9SAndroid Build Coastguard Worker
8316*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as w:
8317*da0073e9SAndroid Build Coastguard Worker            MyAdderBad.apply(a.clone(), b)
8318*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(w), 1)
8319*da0073e9SAndroid Build Coastguard Worker
8320*da0073e9SAndroid Build Coastguard Worker        # II) Multiple outputs
8321*da0073e9SAndroid Build Coastguard Worker        class MyBadAdder(Function):
8322*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8323*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
8324*da0073e9SAndroid Build Coastguard Worker                a.add_(b)
8325*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(a)
8326*da0073e9SAndroid Build Coastguard Worker                return a, a + b
8327*da0073e9SAndroid Build Coastguard Worker
8328*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8329*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, ga, gab):
8330*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
8331*da0073e9SAndroid Build Coastguard Worker                return ga + gab, ga + gab
8332*da0073e9SAndroid Build Coastguard Worker
8333*da0073e9SAndroid Build Coastguard Worker        # No extra inplace
8334*da0073e9SAndroid Build Coastguard Worker        bw_called[0] = 0
8335*da0073e9SAndroid Build Coastguard Worker        c, d = MyBadAdder.apply(a.clone(), b)
8336*da0073e9SAndroid Build Coastguard Worker        (c * d).sum().backward()
8337*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bw_called[0] == 1)
8338*da0073e9SAndroid Build Coastguard Worker
8339*da0073e9SAndroid Build Coastguard Worker        # With extra inplace on the output
8340*da0073e9SAndroid Build Coastguard Worker        bw_called[0] = 0
8341*da0073e9SAndroid Build Coastguard Worker        c, d = MyBadAdder.apply(a.clone(), b)
8342*da0073e9SAndroid Build Coastguard Worker        c += 2
8343*da0073e9SAndroid Build Coastguard Worker        (c * d).sum().backward()
8344*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(bw_called[0] == 1)
8345*da0073e9SAndroid Build Coastguard Worker
8346*da0073e9SAndroid Build Coastguard Worker        # The input is a view
8347*da0073e9SAndroid Build Coastguard Worker        inplace_on_view_err = (
8348*da0073e9SAndroid Build Coastguard Worker            "your Function modifies inplace an input that is a view of another Tensor"
8349*da0073e9SAndroid Build Coastguard Worker        )
8350*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, inplace_on_view_err):
8351*da0073e9SAndroid Build Coastguard Worker            c, d = MyBadAdder.apply(a.clone().view_as(a), b)
8352*da0073e9SAndroid Build Coastguard Worker
8353*da0073e9SAndroid Build Coastguard Worker        # III) Inplace + other op
8354*da0073e9SAndroid Build Coastguard Worker        class MyOutPlaceAdder(Function):
8355*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8356*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a, b):
8357*da0073e9SAndroid Build Coastguard Worker                a.add_(b)
8358*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(a)
8359*da0073e9SAndroid Build Coastguard Worker                return a.clone(), a + b
8360*da0073e9SAndroid Build Coastguard Worker
8361*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8362*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, ga, gab):
8363*da0073e9SAndroid Build Coastguard Worker                bw_called[0] += 1
8364*da0073e9SAndroid Build Coastguard Worker                return ga + gab, ga + 2 * gab
8365*da0073e9SAndroid Build Coastguard Worker
8366*da0073e9SAndroid Build Coastguard Worker        # We don't reuse the input
8367*da0073e9SAndroid Build Coastguard Worker        def fn(a, b):
8368*da0073e9SAndroid Build Coastguard Worker            orig_a = a.clone().view_as(a)
8369*da0073e9SAndroid Build Coastguard Worker            c, d = MyOutPlaceAdder.apply(orig_a, b)
8370*da0073e9SAndroid Build Coastguard Worker            return (c * d).sum()
8371*da0073e9SAndroid Build Coastguard Worker
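        # MyOutPlaceAdder marks `a` as dirty but returns only a.clone() and a + b, so the
        # dirtied tensor is never part of the outputs, which should raise below.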
8372*da0073e9SAndroid Build Coastguard Worker        bad_mark_dirty_err = "Some elements marked as dirty during the forward method were not returned as output."
8373*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
8374*da0073e9SAndroid Build Coastguard Worker            fn(a, b)
8375*da0073e9SAndroid Build Coastguard Worker
8376*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_mark_dirty_not_differentiable(self):
8377*da0073e9SAndroid Build Coastguard Worker        def get_custom_fn(jvp_err):
8378*da0073e9SAndroid Build Coastguard Worker            class InplaceMul(torch.autograd.Function):
8379*da0073e9SAndroid Build Coastguard Worker                @staticmethod
8380*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x):
8381*da0073e9SAndroid Build Coastguard Worker                    result = x.mul_(2)
8382*da0073e9SAndroid Build Coastguard Worker                    ctx.mark_dirty(result)
8383*da0073e9SAndroid Build Coastguard Worker                    return result
8384*da0073e9SAndroid Build Coastguard Worker
8385*da0073e9SAndroid Build Coastguard Worker                @staticmethod
8386*da0073e9SAndroid Build Coastguard Worker                def backward(ctx, grad_output):
8387*da0073e9SAndroid Build Coastguard Worker                    pass
8388*da0073e9SAndroid Build Coastguard Worker
8389*da0073e9SAndroid Build Coastguard Worker                @staticmethod
8390*da0073e9SAndroid Build Coastguard Worker                def jvp(ctx, x_t):
8391*da0073e9SAndroid Build Coastguard Worker                    if jvp_err:
8392*da0073e9SAndroid Build Coastguard Worker                        return x_t
8393*da0073e9SAndroid Build Coastguard Worker                    else:
8394*da0073e9SAndroid Build Coastguard Worker                        return x_t.mul_(2)
8395*da0073e9SAndroid Build Coastguard Worker
8396*da0073e9SAndroid Build Coastguard Worker            return InplaceMul
8397*da0073e9SAndroid Build Coastguard Worker
8398*da0073e9SAndroid Build Coastguard Worker        for requires_grad, jvp_err in product([True, False], repeat=2):
8399*da0073e9SAndroid Build Coastguard Worker            InplaceMul = get_custom_fn(jvp_err)
8400*da0073e9SAndroid Build Coastguard Worker            # Make sure that tensor is always returned as-is if marked dirty
8401*da0073e9SAndroid Build Coastguard Worker            z = torch.tensor(1.0, requires_grad=requires_grad)
8402*da0073e9SAndroid Build Coastguard Worker            x = z.clone()
8403*da0073e9SAndroid Build Coastguard Worker            y = InplaceMul.apply(x)
8404*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x is y)
8405*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, z * 2)
8406*da0073e9SAndroid Build Coastguard Worker
8407*da0073e9SAndroid Build Coastguard Worker            # jvp must modify the corresponding tangent in-place if mark_dirty is set
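            # Illustrative sketch of the forward-AD check below:
            #   x_dual = fwAD.make_dual(x, x_tangent)
            #   InplaceMul.apply(x_dual)   # a jvp returning a fresh tensor (jvp_err=True)
            #                              # raises; mutating x_t in-place is accepted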
8408*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
8409*da0073e9SAndroid Build Coastguard Worker                x_tangent = torch.ones_like(x)
8410*da0073e9SAndroid Build Coastguard Worker                x_dual = fwAD.make_dual(x, x_tangent)
8411*da0073e9SAndroid Build Coastguard Worker
8412*da0073e9SAndroid Build Coastguard Worker                if jvp_err:
8413*da0073e9SAndroid Build Coastguard Worker                    bad_mark_dirty_err = (
8414*da0073e9SAndroid Build Coastguard Worker                        "jvp function must modify the corresponding gradient inplace"
8415*da0073e9SAndroid Build Coastguard Worker                    )
8416*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
8417*da0073e9SAndroid Build Coastguard Worker                        InplaceMul.apply(x_dual)
8418*da0073e9SAndroid Build Coastguard Worker                else:
8419*da0073e9SAndroid Build Coastguard Worker                    out_dual = InplaceMul.apply(x_dual)
8420*da0073e9SAndroid Build Coastguard Worker                    _, out_tangent = fwAD.unpack_dual(out_dual)
8421*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(out_dual is x_dual)
8422*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(out_tangent is x_tangent)
8423*da0073e9SAndroid Build Coastguard Worker
8424*da0073e9SAndroid Build Coastguard Worker    def test_named_tensor_for_complex_views(self):
8425*da0073e9SAndroid Build Coastguard Worker        names = ["batch", "height", "width", "complex"]
8426*da0073e9SAndroid Build Coastguard Worker        z = torch.ones((2, 1, 2, 2), requires_grad=True)
8427*da0073e9SAndroid Build Coastguard Worker        z_named = z.refine_names(*names)
8428*da0073e9SAndroid Build Coastguard Worker        z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(
8429*da0073e9SAndroid Build Coastguard Worker            *names[:-1]
8430*da0073e9SAndroid Build Coastguard Worker        )
8431*da0073e9SAndroid Build Coastguard Worker        z_complex.sum().abs().backward()
8432*da0073e9SAndroid Build Coastguard Worker        expected = torch.ones_like(z_complex).rename(None)
8433*da0073e9SAndroid Build Coastguard Worker        abs_1_1j = abs(1 + 1j)
8434*da0073e9SAndroid Build Coastguard Worker        expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2))
8435*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.grad, torch.view_as_real(expected))
8436*da0073e9SAndroid Build Coastguard Worker
8437*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_return_view_in_nograd(self):
8438*da0073e9SAndroid Build Coastguard Worker        class Alias(Function):
8439*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8440*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
8441*da0073e9SAndroid Build Coastguard Worker                return x[:]
8442*da0073e9SAndroid Build Coastguard Worker
8443*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8444*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gx):
8445*da0073e9SAndroid Build Coastguard Worker                return gx
8446*da0073e9SAndroid Build Coastguard Worker
8447*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2, requires_grad=True)
8448*da0073e9SAndroid Build Coastguard Worker
8449*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
8450*da0073e9SAndroid Build Coastguard Worker            output = Alias.apply(inp)
8451*da0073e9SAndroid Build Coastguard Worker
8452*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
8453*da0073e9SAndroid Build Coastguard Worker            expected_output = inp[:]
8454*da0073e9SAndroid Build Coastguard Worker
8455*da0073e9SAndroid Build Coastguard Worker        # Calling the custom function should operate as if we called an equivalent op
8456*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(output.requires_grad, expected_output.requires_grad)
8457*da0073e9SAndroid Build Coastguard Worker
8458*da0073e9SAndroid Build Coastguard Worker        # Check that in-place modification on view throws
8459*da0073e9SAndroid Build Coastguard Worker        leaf_grad_err = (
8460*da0073e9SAndroid Build Coastguard Worker            "A view was created in no_grad mode and is being modified inplace"
8461*da0073e9SAndroid Build Coastguard Worker        )
8462*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
8463*da0073e9SAndroid Build Coastguard Worker            output.zero_()
8464*da0073e9SAndroid Build Coastguard Worker
8465*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_preserve_torch_function_when_return_as_is(self):
8466*da0073e9SAndroid Build Coastguard Worker        class Custom(torch.Tensor):
8467*da0073e9SAndroid Build Coastguard Worker            def __init__(self, data):
8468*da0073e9SAndroid Build Coastguard Worker                super().__init__()
8469*da0073e9SAndroid Build Coastguard Worker                self._data = data
8470*da0073e9SAndroid Build Coastguard Worker
8471*da0073e9SAndroid Build Coastguard Worker            @classmethod
8472*da0073e9SAndroid Build Coastguard Worker            def __torch_function__(cls, func, types, args=(), kwargs=None):
8473*da0073e9SAndroid Build Coastguard Worker                kwargs = {} if kwargs is None else kwargs
8474*da0073e9SAndroid Build Coastguard Worker                args = tuple(a._data if isinstance(a, cls) else a for a in args)
8475*da0073e9SAndroid Build Coastguard Worker                out = func(*args, **kwargs)
8476*da0073e9SAndroid Build Coastguard Worker                if isinstance(out, torch.Tensor):
8477*da0073e9SAndroid Build Coastguard Worker                    out = cls(out)
8478*da0073e9SAndroid Build Coastguard Worker                return out
8479*da0073e9SAndroid Build Coastguard Worker
8480*da0073e9SAndroid Build Coastguard Worker        class Fn(torch.autograd.Function):
8481*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8482*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
8483*da0073e9SAndroid Build Coastguard Worker                return input
8484*da0073e9SAndroid Build Coastguard Worker
8485*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8486*da0073e9SAndroid Build Coastguard Worker            def backward(ctx):
8487*da0073e9SAndroid Build Coastguard Worker                pass
8488*da0073e9SAndroid Build Coastguard Worker
8489*da0073e9SAndroid Build Coastguard Worker        x = Custom(torch.randn(2, 3))
8490*da0073e9SAndroid Build Coastguard Worker        y = Fn.apply(x)
8491*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(isinstance(y, Custom))
8492*da0073e9SAndroid Build Coastguard Worker
8493*da0073e9SAndroid Build Coastguard Worker    def test_grad_mode_restored_reentrant(self):
8494*da0073e9SAndroid Build Coastguard Worker        class MyFunction(Function):
8495*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8496*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp):
8497*da0073e9SAndroid Build Coastguard Worker                return inp.clone()
8498*da0073e9SAndroid Build Coastguard Worker
8499*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8500*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, go):
8501*da0073e9SAndroid Build Coastguard Worker                original = torch._C.is_grad_enabled()
8502*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
8503*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch._C.is_grad_enabled())
8504*da0073e9SAndroid Build Coastguard Worker                    foo = torch.rand(go.size(), requires_grad=True)
8505*da0073e9SAndroid Build Coastguard Worker                    (grad,) = torch.autograd.grad(foo**3, foo, grad_outputs=go)
8506*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(torch._C.is_grad_enabled())
8507*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch._C.is_grad_enabled() == original)
8508*da0073e9SAndroid Build Coastguard Worker                return grad
8509*da0073e9SAndroid Build Coastguard Worker
8510*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(3, requires_grad=True)
8511*da0073e9SAndroid Build Coastguard Worker
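        # backward() executes MyFunction.backward with grad mode disabled, while
        # create_graph=True keeps grad mode enabled, so both values of `original`
        # are exercised below.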
8512*da0073e9SAndroid Build Coastguard Worker        # Case where original==False
8513*da0073e9SAndroid Build Coastguard Worker        MyFunction.apply(inp).sum().backward()
8514*da0073e9SAndroid Build Coastguard Worker        # Case where original==True
8515*da0073e9SAndroid Build Coastguard Worker        MyFunction.apply(inp).sum().backward(create_graph=True)
8516*da0073e9SAndroid Build Coastguard Worker
8517*da0073e9SAndroid Build Coastguard Worker    def test_power_function(self):
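        # The gradient of 0 ** b should be the same whether the base is a zero
        # Tensor or the Python scalar 0, which exercises the Scalar-base overload
        # of pow.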
8518*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([0.0, 0.0, 0.0])
8519*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
8520*da0073e9SAndroid Build Coastguard Worker        c = torch.sum(a**b)
8521*da0073e9SAndroid Build Coastguard Worker        c.backward()
8522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0]))
8523*da0073e9SAndroid Build Coastguard Worker
8524*da0073e9SAndroid Build Coastguard Worker        s = 0
8525*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
8526*da0073e9SAndroid Build Coastguard Worker        c = torch.sum(s**b)
8527*da0073e9SAndroid Build Coastguard Worker        c.backward()
8528*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0]))
8529*da0073e9SAndroid Build Coastguard Worker
8530*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_error(self):
8531*da0073e9SAndroid Build Coastguard Worker        class BadFw(Function):
8532*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8533*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, foo):
8534*da0073e9SAndroid Build Coastguard Worker                return foo
8535*da0073e9SAndroid Build Coastguard Worker
8536*da0073e9SAndroid Build Coastguard Worker        class BadBw(Function):
8537*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8538*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo):
8539*da0073e9SAndroid Build Coastguard Worker                return foo.clone()
8540*da0073e9SAndroid Build Coastguard Worker
8541*da0073e9SAndroid Build Coastguard Worker        class BadBw2(Function):
8542*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8543*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo):
8544*da0073e9SAndroid Build Coastguard Worker                return foo.clone()
8545*da0073e9SAndroid Build Coastguard Worker
8546*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8547*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, foo):
8548*da0073e9SAndroid Build Coastguard Worker                return foo
8549*da0073e9SAndroid Build Coastguard Worker
8550*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8551*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, foo):
8552*da0073e9SAndroid Build Coastguard Worker                return foo
8553*da0073e9SAndroid Build Coastguard Worker
8554*da0073e9SAndroid Build Coastguard Worker        class BadJvp(Function):
8555*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8556*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo):
8557*da0073e9SAndroid Build Coastguard Worker                return foo.clone()
8558*da0073e9SAndroid Build Coastguard Worker
8559*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(1, requires_grad=True)
8560*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
8561*da0073e9SAndroid Build Coastguard Worker            BadFw.apply(inp)
8562*da0073e9SAndroid Build Coastguard Worker
8563*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
8564*da0073e9SAndroid Build Coastguard Worker            BadBw.apply(inp).sum().backward()
8565*da0073e9SAndroid Build Coastguard Worker
8566*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8567*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Implementing both 'backward' and 'vjp'"
8568*da0073e9SAndroid Build Coastguard Worker        ):
8569*da0073e9SAndroid Build Coastguard Worker            BadBw2.apply(inp).sum().backward()
8570*da0073e9SAndroid Build Coastguard Worker
8571*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"):
8572*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
8573*da0073e9SAndroid Build Coastguard Worker                d = fwAD.make_dual(inp, torch.rand_like(inp))
8574*da0073e9SAndroid Build Coastguard Worker                res = BadJvp.apply(d)
8575*da0073e9SAndroid Build Coastguard Worker
8576*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_forward_mode_view_checks(self):
8577*da0073e9SAndroid Build Coastguard Worker        flag_to_error = {
8578*da0073e9SAndroid Build Coastguard Worker            "ok": None,
8579*da0073e9SAndroid Build Coastguard Worker            "not_a_view": "jvp is not returning a view",
8580*da0073e9SAndroid Build Coastguard Worker            "not_a_view_of_inp": "jvp is not returning a view of the given",
8581*da0073e9SAndroid Build Coastguard Worker            "not_a_view_of_inp_base": "jvp is not returning a view of the same base",
8582*da0073e9SAndroid Build Coastguard Worker        }
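        # forward below returns a narrow() view of its input, so forward-mode AD
        # expects jvp to return a view of the input tangent with the matching base;
        # each non-"ok" flag breaks one of these requirements and should raise the
        # corresponding error above.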
8583*da0073e9SAndroid Build Coastguard Worker
8584*da0073e9SAndroid Build Coastguard Worker        class ViewFn(Function):
8585*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8586*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo, flag):
8587*da0073e9SAndroid Build Coastguard Worker                ctx.flag = flag
8588*da0073e9SAndroid Build Coastguard Worker                ctx.size = foo.size()
8589*da0073e9SAndroid Build Coastguard Worker                return foo.narrow(0, 0, 2)
8590*da0073e9SAndroid Build Coastguard Worker
8591*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8592*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, gO):
8593*da0073e9SAndroid Build Coastguard Worker                gI = gO.new_zeros(ctx.size)
8594*da0073e9SAndroid Build Coastguard Worker                gI.narrow(0, 0, 2).copy_(gO)
8595*da0073e9SAndroid Build Coastguard Worker                return gI, None
8596*da0073e9SAndroid Build Coastguard Worker
8597*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8598*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, gI, _):
8599*da0073e9SAndroid Build Coastguard Worker                res = gI.narrow(0, 0, 2)
8600*da0073e9SAndroid Build Coastguard Worker                if ctx.flag != "ok":
8601*da0073e9SAndroid Build Coastguard Worker                    # Break the view in the gradients!
8602*da0073e9SAndroid Build Coastguard Worker                    res = res.clone()
8603*da0073e9SAndroid Build Coastguard Worker                if ctx.flag in ["not_a_view_of_inp", "not_a_view_of_inp_base"]:
8604*da0073e9SAndroid Build Coastguard Worker                    # Result should be a view, just of the wrong thing
8605*da0073e9SAndroid Build Coastguard Worker                    res = res.view_as(res)
8606*da0073e9SAndroid Build Coastguard Worker                return res
8607*da0073e9SAndroid Build Coastguard Worker
8608*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
8609*da0073e9SAndroid Build Coastguard Worker
8610*da0073e9SAndroid Build Coastguard Worker        for flag, msg in flag_to_error.items():
8611*da0073e9SAndroid Build Coastguard Worker
8612*da0073e9SAndroid Build Coastguard Worker            def test_fn(inp):
8613*da0073e9SAndroid Build Coastguard Worker                if flag == "not_a_view_of_inp_base":
8614*da0073e9SAndroid Build Coastguard Worker                    inp = inp.view_as(inp)
8615*da0073e9SAndroid Build Coastguard Worker                return ViewFn.apply(inp, flag)
8616*da0073e9SAndroid Build Coastguard Worker
8617*da0073e9SAndroid Build Coastguard Worker            if msg is None:
8618*da0073e9SAndroid Build Coastguard Worker                gradcheck(test_fn, inp, check_forward_ad=True)
8619*da0073e9SAndroid Build Coastguard Worker            else:
8620*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, msg):
8621*da0073e9SAndroid Build Coastguard Worker                    gradcheck(test_fn, inp, check_forward_ad=True)
8622*da0073e9SAndroid Build Coastguard Worker
8623*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_forward_mode_inplace_checks(self):
8624*da0073e9SAndroid Build Coastguard Worker        class InplaceFn(Function):
8625*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8626*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo, flag):
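                # mark_dirty declares that `foo` is modified in-place; forward-mode
                # AD then requires jvp to update the input tangent in-place as well.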
8627*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(foo)
8628*da0073e9SAndroid Build Coastguard Worker                ctx.flag = flag
8629*da0073e9SAndroid Build Coastguard Worker                foo.mul_(2)
8630*da0073e9SAndroid Build Coastguard Worker                return foo
8631*da0073e9SAndroid Build Coastguard Worker
8632*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8633*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, gO):
8634*da0073e9SAndroid Build Coastguard Worker                return 2 * gO, None
8635*da0073e9SAndroid Build Coastguard Worker
8636*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8637*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, gI, _):
8638*da0073e9SAndroid Build Coastguard Worker                if ctx.flag:
8639*da0073e9SAndroid Build Coastguard Worker                    # Don't modify the tangent in-place (this should be rejected)
8640*da0073e9SAndroid Build Coastguard Worker                    return 2 * gI
8641*da0073e9SAndroid Build Coastguard Worker                else:
8642*da0073e9SAndroid Build Coastguard Worker                    gI.mul_(2)
8643*da0073e9SAndroid Build Coastguard Worker                    return gI
8644*da0073e9SAndroid Build Coastguard Worker
8645*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
8646*da0073e9SAndroid Build Coastguard Worker
8647*da0073e9SAndroid Build Coastguard Worker        def test_fn(inp, flag):
8648*da0073e9SAndroid Build Coastguard Worker            inp = inp.clone()
8649*da0073e9SAndroid Build Coastguard Worker            return InplaceFn.apply(inp, flag)
8650*da0073e9SAndroid Build Coastguard Worker
8651*da0073e9SAndroid Build Coastguard Worker        gradcheck(test_fn, (inp, False), check_forward_ad=True)
8652*da0073e9SAndroid Build Coastguard Worker
8653*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8654*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
8655*da0073e9SAndroid Build Coastguard Worker            "inplace custom Function is not modifying the forward mode gradients inplace",
8656*da0073e9SAndroid Build Coastguard Worker        ):
8657*da0073e9SAndroid Build Coastguard Worker            gradcheck(test_fn, (inp, True), check_forward_ad=True)
8658*da0073e9SAndroid Build Coastguard Worker
8659*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_forward_mode_wrong_formula(self):
8660*da0073e9SAndroid Build Coastguard Worker        class UserFn(Function):
8661*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8662*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, foo, should_fail):
8663*da0073e9SAndroid Build Coastguard Worker                ctx.should_fail = should_fail
8664*da0073e9SAndroid Build Coastguard Worker                return foo * 2
8665*da0073e9SAndroid Build Coastguard Worker
8666*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8667*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, gO):
8668*da0073e9SAndroid Build Coastguard Worker                return 2 * gO, None
8669*da0073e9SAndroid Build Coastguard Worker
8670*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8671*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, gI, _):
8672*da0073e9SAndroid Build Coastguard Worker                if ctx.should_fail:
8673*da0073e9SAndroid Build Coastguard Worker                    # Wrong gradient formula
8674*da0073e9SAndroid Build Coastguard Worker                    return 3 * gI
8675*da0073e9SAndroid Build Coastguard Worker                else:
8676*da0073e9SAndroid Build Coastguard Worker                    return 2 * gI
8677*da0073e9SAndroid Build Coastguard Worker
8678*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(10, dtype=torch.double, requires_grad=True)
8679*da0073e9SAndroid Build Coastguard Worker        gradcheck(UserFn.apply, (inp, False), check_forward_ad=True)
8680*da0073e9SAndroid Build Coastguard Worker
8681*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
8682*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Jacobian computed with forward mode mismatch for output 0"
8683*da0073e9SAndroid Build Coastguard Worker        ):
8684*da0073e9SAndroid Build Coastguard Worker            gradcheck(UserFn.apply, (inp, True), check_forward_ad=True)
8685*da0073e9SAndroid Build Coastguard Worker
8686*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_forward_mode_non_tensor_before_tensor_args(self):
8687*da0073e9SAndroid Build Coastguard Worker        class MyFn(torch.autograd.Function):
8688*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8689*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, nt, x, nt2, y):
8690*da0073e9SAndroid Build Coastguard Worker                return x * 2 + y * 3
8691*da0073e9SAndroid Build Coastguard Worker
8692*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8693*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, nt, x_t, nt2, y_t):
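                # Non-tensor inputs have no tangent, so their jvp arguments are None.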
8694*da0073e9SAndroid Build Coastguard Worker                self.assertIsNone(nt)
8695*da0073e9SAndroid Build Coastguard Worker                self.assertIsNone(nt2)
8696*da0073e9SAndroid Build Coastguard Worker                return x_t * 2 + y_t * 3
8697*da0073e9SAndroid Build Coastguard Worker
8698*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1.0, dtype=torch.double)
8699*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(1.0, dtype=torch.double)
8700*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(1.0, dtype=torch.double)
8701*da0073e9SAndroid Build Coastguard Worker
8702*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8703*da0073e9SAndroid Build Coastguard Worker            dual_x = fwAD.make_dual(x, t)
8704*da0073e9SAndroid Build Coastguard Worker            MyFn.apply(1, dual_x, 1, y)
8705*da0073e9SAndroid Build Coastguard Worker
8706*da0073e9SAndroid Build Coastguard Worker        gradcheck(
8707*da0073e9SAndroid Build Coastguard Worker            MyFn.apply,
8708*da0073e9SAndroid Build Coastguard Worker            (1, x.requires_grad_(True), 1, y.requires_grad_(True)),
8709*da0073e9SAndroid Build Coastguard Worker            check_forward_ad=True,
8710*da0073e9SAndroid Build Coastguard Worker            check_backward_ad=False,
8711*da0073e9SAndroid Build Coastguard Worker            check_batched_grad=False,
8712*da0073e9SAndroid Build Coastguard Worker        )
8713*da0073e9SAndroid Build Coastguard Worker
8714*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_forward_mode_forward_is_no_op(self):
8715*da0073e9SAndroid Build Coastguard Worker        error_regex = (
8716*da0073e9SAndroid Build Coastguard Worker            "A custom Function's forward is returning a view \\(or an input as-is\\)"
8717*da0073e9SAndroid Build Coastguard Worker        )
8718*da0073e9SAndroid Build Coastguard Worker
8719*da0073e9SAndroid Build Coastguard Worker        return_lambdas = {
8720*da0073e9SAndroid Build Coastguard Worker            # If we return an input as-is in forward, that is treated
8721*da0073e9SAndroid Build Coastguard Worker            # as if self.view_as(self) is performed. If jvp returns x.view_as(x),
8722*da0073e9SAndroid Build Coastguard Worker            # this is OK.
8723*da0073e9SAndroid Build Coastguard Worker            "view_as": lambda x: x.view_as(x),
8724*da0073e9SAndroid Build Coastguard Worker            # Expect this to raise an error
8725*da0073e9SAndroid Build Coastguard Worker            "self": lambda x: x,
8726*da0073e9SAndroid Build Coastguard Worker            # Expect this to raise the same error
8727*da0073e9SAndroid Build Coastguard Worker            "mul_by_2": lambda x: x * 2,
8728*da0073e9SAndroid Build Coastguard Worker        }
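        # Because forward returns its input as-is (recorded as x.view_as(x)), the
        # matching jvp output must itself be a view of the input tangent; returning
        # the tangent unchanged or a freshly allocated tensor is expected to raise
        # the error above.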
8729*da0073e9SAndroid Build Coastguard Worker
8730*da0073e9SAndroid Build Coastguard Worker        for k, fn in return_lambdas.items():
8731*da0073e9SAndroid Build Coastguard Worker
8732*da0073e9SAndroid Build Coastguard Worker            class MyFn(torch.autograd.Function):
8733*da0073e9SAndroid Build Coastguard Worker                @staticmethod
8734*da0073e9SAndroid Build Coastguard Worker                def forward(ctx, x, y):
8735*da0073e9SAndroid Build Coastguard Worker                    return x + y, x
8736*da0073e9SAndroid Build Coastguard Worker
8737*da0073e9SAndroid Build Coastguard Worker                @staticmethod
8738*da0073e9SAndroid Build Coastguard Worker                def vjp(ctx, gO1, gO2):
8739*da0073e9SAndroid Build Coastguard Worker                    return gO1 + gO2, gO1
8740*da0073e9SAndroid Build Coastguard Worker
8741*da0073e9SAndroid Build Coastguard Worker                @staticmethod
8742*da0073e9SAndroid Build Coastguard Worker                def jvp(ctx, x_t, y_t):
8743*da0073e9SAndroid Build Coastguard Worker                    return x_t + y_t, fn(x_t)
8744*da0073e9SAndroid Build Coastguard Worker
8745*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
8746*da0073e9SAndroid Build Coastguard Worker            t = torch.tensor(1.0, dtype=torch.double)
8747*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
8748*da0073e9SAndroid Build Coastguard Worker
8749*da0073e9SAndroid Build Coastguard Worker            c = torch.tensor(1.0, dtype=torch.double)
8750*da0073e9SAndroid Build Coastguard Worker            t2 = torch.tensor(1.0, dtype=torch.double)
8751*da0073e9SAndroid Build Coastguard Worker            d = torch.tensor(1.0, dtype=torch.double)
8752*da0073e9SAndroid Build Coastguard Worker
8753*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
8754*da0073e9SAndroid Build Coastguard Worker                a_dual = fwAD.make_dual(a, t)
8755*da0073e9SAndroid Build Coastguard Worker                c_dual = fwAD.make_dual(c, t2)
8756*da0073e9SAndroid Build Coastguard Worker
8757*da0073e9SAndroid Build Coastguard Worker                if k == "view_as":
8758*da0073e9SAndroid Build Coastguard Worker                    _, out2 = MyFn.apply(a_dual, b)
8759*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t)
8760*da0073e9SAndroid Build Coastguard Worker
8761*da0073e9SAndroid Build Coastguard Worker                    _, out2 = MyFn.apply(c_dual, d)
8762*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t2)
8763*da0073e9SAndroid Build Coastguard Worker                else:
8764*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, error_regex):
8765*da0073e9SAndroid Build Coastguard Worker                        MyFn.apply(a_dual, b)
8766*da0073e9SAndroid Build Coastguard Worker
8767*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, error_regex):
8768*da0073e9SAndroid Build Coastguard Worker                        MyFn.apply(c_dual, d)
8769*da0073e9SAndroid Build Coastguard Worker
8770*da0073e9SAndroid Build Coastguard Worker            if k == "view_as":
8771*da0073e9SAndroid Build Coastguard Worker                gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
8772*da0073e9SAndroid Build Coastguard Worker            else:
8773*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, error_regex):
8774*da0073e9SAndroid Build Coastguard Worker                    gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
8775*da0073e9SAndroid Build Coastguard Worker
8776*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_save_for_forward(self):
8777*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8778*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8779*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
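                # save_for_backward exposes (x, y) to vjp via ctx.saved_tensors,
                # while save_for_forward makes them available to jvp.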
8780*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x, y)
8781*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_forward(x, y)
8782*da0073e9SAndroid Build Coastguard Worker                ctx.z = z
8783*da0073e9SAndroid Build Coastguard Worker                ctx.prod = x * y
8784*da0073e9SAndroid Build Coastguard Worker                return z * ctx.prod
8785*da0073e9SAndroid Build Coastguard Worker
8786*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8787*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_t, y_t, _):
8788*da0073e9SAndroid Build Coastguard Worker                x_p, y_p = ctx.saved_tensors
8789*da0073e9SAndroid Build Coastguard Worker                z = ctx.z
8790*da0073e9SAndroid Build Coastguard Worker                return z * (y_p * x_t + x_p * y_t)
8791*da0073e9SAndroid Build Coastguard Worker
8792*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8793*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, grad_out):
8794*da0073e9SAndroid Build Coastguard Worker                x, y = ctx.saved_tensors
8795*da0073e9SAndroid Build Coastguard Worker                z = ctx.z
8796*da0073e9SAndroid Build Coastguard Worker                return z * grad_out * y, z * grad_out * x, None
8797*da0073e9SAndroid Build Coastguard Worker
8798*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True, dtype=torch.double)
8799*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor(1.0, dtype=torch.double)
8800*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(2.0, requires_grad=True, dtype=torch.double)
8801*da0073e9SAndroid Build Coastguard Worker        c = 4
8802*da0073e9SAndroid Build Coastguard Worker
8803*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8804*da0073e9SAndroid Build Coastguard Worker            a_dual = fwAD.make_dual(a, t)
8805*da0073e9SAndroid Build Coastguard Worker            out = Func.apply(a_dual, b, c)
8806*da0073e9SAndroid Build Coastguard Worker            out.backward()
8807*da0073e9SAndroid Build Coastguard Worker
8808*da0073e9SAndroid Build Coastguard Worker        gradcheck(Func.apply, (a, b, c), check_forward_ad=True)
8809*da0073e9SAndroid Build Coastguard Worker
8810*da0073e9SAndroid Build Coastguard Worker        # When saved for backward, but not saved for forward
8811*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8812*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8813*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x: torch.Tensor):
8814*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
8815*da0073e9SAndroid Build Coastguard Worker                return x.clone()
8816*da0073e9SAndroid Build Coastguard Worker
8817*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8818*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_t):
8819*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(ctx.saved_tensors), 0)
8820*da0073e9SAndroid Build Coastguard Worker                return x_t
8821*da0073e9SAndroid Build Coastguard Worker
8822*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8823*da0073e9SAndroid Build Coastguard Worker            def vjp(ctx, grad_out):
8824*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
8825*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(ctx.saved_tensors), 1)
8826*da0073e9SAndroid Build Coastguard Worker                return grad_out
8827*da0073e9SAndroid Build Coastguard Worker
8828*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8829*da0073e9SAndroid Build Coastguard Worker            a_dual = fwAD.make_dual(a, t)
8830*da0073e9SAndroid Build Coastguard Worker            out = Func.apply(a_dual)
8831*da0073e9SAndroid Build Coastguard Worker            out.backward()
8832*da0073e9SAndroid Build Coastguard Worker
8833*da0073e9SAndroid Build Coastguard Worker        gradcheck(Func.apply, (a,), check_forward_ad=True)
8834*da0073e9SAndroid Build Coastguard Worker
8835*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
8836*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_forward_mode_non_differentiable(self):
8837*da0073e9SAndroid Build Coastguard Worker        # returns differentiable type, marked non-differentiable
8838*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8839*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8840*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
8841*da0073e9SAndroid Build Coastguard Worker                out = y.clone()
8842*da0073e9SAndroid Build Coastguard Worker                ctx.mark_non_differentiable(out)
8843*da0073e9SAndroid Build Coastguard Worker                return x.clone(), out
8844*da0073e9SAndroid Build Coastguard Worker
8845*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8846*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_tangent, y_tangent):
8847*da0073e9SAndroid Build Coastguard Worker                return x_tangent, None
8848*da0073e9SAndroid Build Coastguard Worker
8849*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2.0)
8850*da0073e9SAndroid Build Coastguard Worker        x_tangent = torch.tensor(1.0)
8851*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(3.0)
8852*da0073e9SAndroid Build Coastguard Worker
8853*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8854*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(x, x_tangent)
8855*da0073e9SAndroid Build Coastguard Worker            _, out2_dual = Func.apply(x_dual, y)
8856*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)
8857*da0073e9SAndroid Build Coastguard Worker
8858*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(3)
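        # Switch y to an integer tensor: integral outputs are non-differentiable
        # even without mark_non_differentiable.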
8859*da0073e9SAndroid Build Coastguard Worker
8860*da0073e9SAndroid Build Coastguard Worker        # returns non-differentiable type, NOT marked non-differentiable
8861*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8862*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8863*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
8864*da0073e9SAndroid Build Coastguard Worker                return x.clone(), y.clone()
8865*da0073e9SAndroid Build Coastguard Worker
8866*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8867*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_tangent, y_tangent):
8868*da0073e9SAndroid Build Coastguard Worker                self.assertIsNone(y_tangent)
8869*da0073e9SAndroid Build Coastguard Worker                return x_tangent, None
8870*da0073e9SAndroid Build Coastguard Worker
8871*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8872*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(x, x_tangent)
8873*da0073e9SAndroid Build Coastguard Worker            _, out2_dual = Func.apply(x_dual, y)
8874*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)
8875*da0073e9SAndroid Build Coastguard Worker
8876*da0073e9SAndroid Build Coastguard Worker        class FuncWrong(torch.autograd.Function):
8877*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8878*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
8879*da0073e9SAndroid Build Coastguard Worker                out = y.clone()
8880*da0073e9SAndroid Build Coastguard Worker                ctx.mark_non_differentiable(out)
8881*da0073e9SAndroid Build Coastguard Worker                return x.clone(), out
8882*da0073e9SAndroid Build Coastguard Worker
8883*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8884*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_tangent, y_tangent):
8885*da0073e9SAndroid Build Coastguard Worker                return x_tangent, x_tangent.clone()
8886*da0073e9SAndroid Build Coastguard Worker
8887*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8888*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(x, x_tangent)
8889*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
8890*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "You should return None at that position instead"
8891*da0073e9SAndroid Build Coastguard Worker            ):
8892*da0073e9SAndroid Build Coastguard Worker                FuncWrong.apply(x_dual, y)
8893*da0073e9SAndroid Build Coastguard Worker
8894*da0073e9SAndroid Build Coastguard Worker        # returns non-tensor
8895*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
8896*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8897*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
8898*da0073e9SAndroid Build Coastguard Worker                return x.clone(), object(), x.clone()
8899*da0073e9SAndroid Build Coastguard Worker
8900*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8901*da0073e9SAndroid Build Coastguard Worker            def jvp(ctx, x_tangent):
8902*da0073e9SAndroid Build Coastguard Worker                return x_tangent, None, x_tangent
8903*da0073e9SAndroid Build Coastguard Worker
8904*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
8905*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(x, x_tangent)
8906*da0073e9SAndroid Build Coastguard Worker            out_dual, _, out2_dual = Func.apply(x_dual)
8907*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(out_dual).tangent, x_tangent)
8908*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, x_tangent)
8909*da0073e9SAndroid Build Coastguard Worker
8910*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_local_inplace(self):
8911*da0073e9SAndroid Build Coastguard Worker        class MyFn(torch.autograd.Function):
8912*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8913*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp, inplace):
8914*da0073e9SAndroid Build Coastguard Worker                view = inp.clone()[:3]
8915*da0073e9SAndroid Build Coastguard Worker                if inplace:
8916*da0073e9SAndroid Build Coastguard Worker                    view += 2
8917*da0073e9SAndroid Build Coastguard Worker                return view
8918*da0073e9SAndroid Build Coastguard Worker
8919*da0073e9SAndroid Build Coastguard Worker            @staticmethod
8920*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
8921*da0073e9SAndroid Build Coastguard Worker                return grad, None
8922*da0073e9SAndroid Build Coastguard Worker
8923*da0073e9SAndroid Build Coastguard Worker        base = torch.rand(10, requires_grad=True)
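        # The in-place update inside forward only touches a view of a local clone
        # (not an input), so it is allowed and the output's grad_fn should still be
        # the custom Function's backward node.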
8924*da0073e9SAndroid Build Coastguard Worker
8925*da0073e9SAndroid Build Coastguard Worker        foo = MyFn.apply(base, False)
8926*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward")
8927*da0073e9SAndroid Build Coastguard Worker
8928*da0073e9SAndroid Build Coastguard Worker        foo = MyFn.apply(base, True)
8929*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward")
8930*da0073e9SAndroid Build Coastguard Worker
8931*da0073e9SAndroid Build Coastguard Worker    def test_integer_outputs(self):
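        # Ops with integer-valued outputs (argmax, argsort, searchsorted, ...) must
        # never mark their results as requiring grad, even when the inputs do.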
8932*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(4, requires_grad=True)
8933*da0073e9SAndroid Build Coastguard Worker
8934*da0073e9SAndroid Build Coastguard Worker        out = inp.argmax()
8935*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.dtype.is_floating_point)
8936*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
8937*da0073e9SAndroid Build Coastguard Worker
8938*da0073e9SAndroid Build Coastguard Worker        out = inp.argmin()
8939*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.dtype.is_floating_point)
8940*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
8941*da0073e9SAndroid Build Coastguard Worker
8942*da0073e9SAndroid Build Coastguard Worker        out = inp.argsort()
8943*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.dtype.is_floating_point)
8944*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
8945*da0073e9SAndroid Build Coastguard Worker
8946*da0073e9SAndroid Build Coastguard Worker        val = torch.rand((), requires_grad=True)
8947*da0073e9SAndroid Build Coastguard Worker
8948*da0073e9SAndroid Build Coastguard Worker        out = torch.searchsorted(inp, val)
8949*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.dtype.is_floating_point)
8950*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
8951*da0073e9SAndroid Build Coastguard Worker
8952*da0073e9SAndroid Build Coastguard Worker        bins = torch.linspace(0, 1.0, steps=100, requires_grad=True)
8953*da0073e9SAndroid Build Coastguard Worker        vals = torch.rand(5, 5, requires_grad=True)
8954*da0073e9SAndroid Build Coastguard Worker        out = torch.bucketize(vals, bins)
8955*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.dtype.is_floating_point)
8956*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
8957*da0073e9SAndroid Build Coastguard Worker
8958*da0073e9SAndroid Build Coastguard Worker        val = torch.empty(5).requires_grad_()
8959*da0073e9SAndroid Build Coastguard Worker        out = val.count_nonzero()
8960*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
8961*da0073e9SAndroid Build Coastguard Worker
8962*da0073e9SAndroid Build Coastguard Worker        def assert_only_first_requires_grad(res):
8963*da0073e9SAndroid Build Coastguard Worker            if not isinstance(res, tuple):
8964*da0073e9SAndroid Build Coastguard Worker                res = (res,)
8965*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(res[0].requires_grad)
8966*da0073e9SAndroid Build Coastguard Worker            for out in res[1:]:
8967*da0073e9SAndroid Build Coastguard Worker                if out is not None:
8968*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(out.requires_grad)
8969*da0073e9SAndroid Build Coastguard Worker
8970*da0073e9SAndroid Build Coastguard Worker        for sort in [True, False]:
8971*da0073e9SAndroid Build Coastguard Worker            for return_inverse in [True, False]:
8972*da0073e9SAndroid Build Coastguard Worker                for return_counts in [True, False]:
8973*da0073e9SAndroid Build Coastguard Worker                    res = torch.unique(
8974*da0073e9SAndroid Build Coastguard Worker                        inp,
8975*da0073e9SAndroid Build Coastguard Worker                        sorted=sort,
8976*da0073e9SAndroid Build Coastguard Worker                        return_inverse=return_inverse,
8977*da0073e9SAndroid Build Coastguard Worker                        return_counts=return_counts,
8978*da0073e9SAndroid Build Coastguard Worker                    )
8979*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
8980*da0073e9SAndroid Build Coastguard Worker
8981*da0073e9SAndroid Build Coastguard Worker                    res = torch.unique(
8982*da0073e9SAndroid Build Coastguard Worker                        inp,
8983*da0073e9SAndroid Build Coastguard Worker                        sorted=sort,
8984*da0073e9SAndroid Build Coastguard Worker                        return_inverse=return_inverse,
8985*da0073e9SAndroid Build Coastguard Worker                        return_counts=return_counts,
8986*da0073e9SAndroid Build Coastguard Worker                        dim=0,
8987*da0073e9SAndroid Build Coastguard Worker                    )
8988*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
8989*da0073e9SAndroid Build Coastguard Worker
8990*da0073e9SAndroid Build Coastguard Worker                    res = torch.unique_consecutive(
8991*da0073e9SAndroid Build Coastguard Worker                        inp, return_inverse=return_inverse, return_counts=return_counts
8992*da0073e9SAndroid Build Coastguard Worker                    )
8993*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
8994*da0073e9SAndroid Build Coastguard Worker
8995*da0073e9SAndroid Build Coastguard Worker                    res = torch.unique_consecutive(
8996*da0073e9SAndroid Build Coastguard Worker                        inp,
8997*da0073e9SAndroid Build Coastguard Worker                        return_inverse=return_inverse,
8998*da0073e9SAndroid Build Coastguard Worker                        return_counts=return_counts,
8999*da0073e9SAndroid Build Coastguard Worker                        dim=0,
9000*da0073e9SAndroid Build Coastguard Worker                    )
9001*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
9002*da0073e9SAndroid Build Coastguard Worker
9003*da0073e9SAndroid Build Coastguard Worker                    # Here we test the internal functions to make sure all of them are
9004*da0073e9SAndroid Build Coastguard Worker                    # covered on top of the public API
9005*da0073e9SAndroid Build Coastguard Worker                    res = torch._unique(inp, sorted=sort, return_inverse=return_inverse)
9006*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
9007*da0073e9SAndroid Build Coastguard Worker
9008*da0073e9SAndroid Build Coastguard Worker                    # This looks public but is actually manually deleted from the
9009*da0073e9SAndroid Build Coastguard Worker                    # torch namespace in torch/functional.py
9010*da0073e9SAndroid Build Coastguard Worker                    res = torch._VF.unique_dim(
9011*da0073e9SAndroid Build Coastguard Worker                        inp,
9012*da0073e9SAndroid Build Coastguard Worker                        dim=0,
9013*da0073e9SAndroid Build Coastguard Worker                        sorted=sort,
9014*da0073e9SAndroid Build Coastguard Worker                        return_inverse=return_inverse,
9015*da0073e9SAndroid Build Coastguard Worker                        return_counts=return_counts,
9016*da0073e9SAndroid Build Coastguard Worker                    )
9017*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
9018*da0073e9SAndroid Build Coastguard Worker
9019*da0073e9SAndroid Build Coastguard Worker                    # We don't test `unique_dim_consecutive` here.
9020*da0073e9SAndroid Build Coastguard Worker                    # It looks public but the python binding is actually manually disabled in
9021*da0073e9SAndroid Build Coastguard Worker                    # tools/autograd/gen_python_functions.py
9022*da0073e9SAndroid Build Coastguard Worker
9023*da0073e9SAndroid Build Coastguard Worker                    res = torch._unique2(
9024*da0073e9SAndroid Build Coastguard Worker                        inp,
9025*da0073e9SAndroid Build Coastguard Worker                        sorted=sort,
9026*da0073e9SAndroid Build Coastguard Worker                        return_inverse=return_inverse,
9027*da0073e9SAndroid Build Coastguard Worker                        return_counts=return_counts,
9028*da0073e9SAndroid Build Coastguard Worker                    )
9029*da0073e9SAndroid Build Coastguard Worker                    assert_only_first_requires_grad(res)
9030*da0073e9SAndroid Build Coastguard Worker
9031*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_cycle(self):
9032*da0073e9SAndroid Build Coastguard Worker        class MyFn(Function):
9033*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9034*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, metadata):
9035*da0073e9SAndroid Build Coastguard Worker                x = x.clone()
9036*da0073e9SAndroid Build Coastguard Worker                ctx.meta = metadata
9037*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
9038*da0073e9SAndroid Build Coastguard Worker                return x
9039*da0073e9SAndroid Build Coastguard Worker
9040*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9041*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
9042*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
9043*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x, 3.14)
9044*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(ctx.meta["foo"], 3.14)
9045*da0073e9SAndroid Build Coastguard Worker                return gO * x, None
9046*da0073e9SAndroid Build Coastguard Worker
9047*da0073e9SAndroid Build Coastguard Worker        def get_refs(with_backward):
9048*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(3.14, requires_grad=True)
9049*da0073e9SAndroid Build Coastguard Worker
9050*da0073e9SAndroid Build Coastguard Worker            metadata = {}
9051*da0073e9SAndroid Build Coastguard Worker            out = MyFn.apply(a, metadata)
9052*da0073e9SAndroid Build Coastguard Worker
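            # Close the reference cycle: out -> grad_fn -> ctx.meta -> out. Only the
            # garbage collector can reclaim `out` once this function returns.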
9053*da0073e9SAndroid Build Coastguard Worker            metadata["foo"] = out
9054*da0073e9SAndroid Build Coastguard Worker
9055*da0073e9SAndroid Build Coastguard Worker            if with_backward:
9056*da0073e9SAndroid Build Coastguard Worker                out.sum().backward()
9057*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.grad, a)
9058*da0073e9SAndroid Build Coastguard Worker
9059*da0073e9SAndroid Build Coastguard Worker            return torch._C._WeakTensorRef(out)
9060*da0073e9SAndroid Build Coastguard Worker
9061*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
9062*da0073e9SAndroid Build Coastguard Worker            ref = get_refs(False)
9063*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(ref.expired())
9064*da0073e9SAndroid Build Coastguard Worker        gc.collect()
9065*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref.expired())
9066*da0073e9SAndroid Build Coastguard Worker
9067*da0073e9SAndroid Build Coastguard Worker        # The backward clears the saved_variables but not the __dict__
9068*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
9069*da0073e9SAndroid Build Coastguard Worker            ref = get_refs(True)
9070*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(ref.expired())
9071*da0073e9SAndroid Build Coastguard Worker        gc.collect()
9072*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(ref.expired())
9073*da0073e9SAndroid Build Coastguard Worker
9074*da0073e9SAndroid Build Coastguard Worker    def test_create_graph_and_full_backward_hook_cycle(self):
9075*da0073e9SAndroid Build Coastguard Worker        # If BackwardHook saves grad_output, it can create a cycle when we perform backward
9076*da0073e9SAndroid Build Coastguard Worker        # with create_graph=True
9077*da0073e9SAndroid Build Coastguard Worker        #
9078*da0073e9SAndroid Build Coastguard Worker        #   grad_output -> grad_output.grad_fn -> graph -> hook -> grad_output
9079*da0073e9SAndroid Build Coastguard Worker        #
9080*da0073e9SAndroid Build Coastguard Worker        class TestCls:
9081*da0073e9SAndroid Build Coastguard Worker            # Dummy class for the purpose of creating a weakref
9082*da0073e9SAndroid Build Coastguard Worker            pass
9083*da0073e9SAndroid Build Coastguard Worker
9084*da0073e9SAndroid Build Coastguard Worker        def get_ref(input_requires_grad, nb_hooks):
9085*da0073e9SAndroid Build Coastguard Worker            t = torch.randn(10, requires_grad=input_requires_grad)
9086*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1.0, requires_grad=True)
9087*da0073e9SAndroid Build Coastguard Worker
9088*da0073e9SAndroid Build Coastguard Worker            class Test(nn.Module):
9089*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
9090*da0073e9SAndroid Build Coastguard Worker                    return x**2 * a**2
9091*da0073e9SAndroid Build Coastguard Worker
9092*da0073e9SAndroid Build Coastguard Worker            mod = Test()
9093*da0073e9SAndroid Build Coastguard Worker
9094*da0073e9SAndroid Build Coastguard Worker            for _ in range(nb_hooks):
9095*da0073e9SAndroid Build Coastguard Worker                mod.register_full_backward_hook(lambda a, b, c: None)
9096*da0073e9SAndroid Build Coastguard Worker
9097*da0073e9SAndroid Build Coastguard Worker            tmp = mod(t)
9098*da0073e9SAndroid Build Coastguard Worker
9099*da0073e9SAndroid Build Coastguard Worker            # Save dummy object to graph and get a weak ref to it
9100*da0073e9SAndroid Build Coastguard Worker            test = TestCls()
9101*da0073e9SAndroid Build Coastguard Worker            ref = weakref.ref(test)
9102*da0073e9SAndroid Build Coastguard Worker            tmp.grad_fn.metadata["a"] = test
9103*da0073e9SAndroid Build Coastguard Worker
9104*da0073e9SAndroid Build Coastguard Worker            with set_warn_always_context(True):
9105*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as w:
9106*da0073e9SAndroid Build Coastguard Worker                    tmp.exp().sum().backward(create_graph=True)
9107*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(len(w) == 1)
9108*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(
9109*da0073e9SAndroid Build Coastguard Worker                        "Using backward() with create_graph=True" in str(w[0].message)
9110*da0073e9SAndroid Build Coastguard Worker                    )
9111*da0073e9SAndroid Build Coastguard Worker
9112*da0073e9SAndroid Build Coastguard Worker            # Remove the backward + create_graph=True cycle
9113*da0073e9SAndroid Build Coastguard Worker            a.grad = None
9114*da0073e9SAndroid Build Coastguard Worker            t.grad = None
9115*da0073e9SAndroid Build Coastguard Worker
9116*da0073e9SAndroid Build Coastguard Worker            return ref
9117*da0073e9SAndroid Build Coastguard Worker
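        # If the BackwardHook / create_graph cycle kept the graph alive, the dummy
        # object stored in grad_fn.metadata would leak; gc.collect() must be able
        # to reclaim it.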
9118*da0073e9SAndroid Build Coastguard Worker        for nb_hooks in (1, 2, 3):
9119*da0073e9SAndroid Build Coastguard Worker            for input_requires_grad in (True, False):
9120*da0073e9SAndroid Build Coastguard Worker                ref_ = get_ref(
9121*da0073e9SAndroid Build Coastguard Worker                    input_requires_grad=input_requires_grad,
9122*da0073e9SAndroid Build Coastguard Worker                    nb_hooks=nb_hooks,
9123*da0073e9SAndroid Build Coastguard Worker                )
9124*da0073e9SAndroid Build Coastguard Worker                gc.collect()
9125*da0073e9SAndroid Build Coastguard Worker                self.assertIsNone(ref_())
9126*da0073e9SAndroid Build Coastguard Worker
9127*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_custom_function", [True, False])
9128*da0073e9SAndroid Build Coastguard Worker    @parametrize("use_tensor_hook", [True, False])
9129*da0073e9SAndroid Build Coastguard Worker    def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook):
9130*da0073e9SAndroid Build Coastguard Worker        # This creates a cycle between the hook and grad_fn_b:
9131*da0073e9SAndroid Build Coastguard Worker        # hook -> closure -> grad_fn_b (python) -> grad_fn_b (cpp) -> hook (cpp)
9132*da0073e9SAndroid Build Coastguard Worker        # -> dict -> hook
9133*da0073e9SAndroid Build Coastguard Worker        #
9134*da0073e9SAndroid Build Coastguard Worker        # This test checks that grad_fn_b (python) only traverses the dict if
9135*da0073e9SAndroid Build Coastguard Worker        # it is the only one holding a reference to the grad_fn_b (cpp)
9136*da0073e9SAndroid Build Coastguard Worker        # shared_ptr
9137*da0073e9SAndroid Build Coastguard Worker        #
9138*da0073e9SAndroid Build Coastguard Worker        # See: https://github.com/pytorch/pytorch/issues/102174
9139*da0073e9SAndroid Build Coastguard Worker        class Function(torch.autograd.Function):
9140*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9141*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
9142*da0073e9SAndroid Build Coastguard Worker                return x
9143*da0073e9SAndroid Build Coastguard Worker
9144*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9145*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
9146*da0073e9SAndroid Build Coastguard Worker                return grad
9147*da0073e9SAndroid Build Coastguard Worker
9148*da0073e9SAndroid Build Coastguard Worker        class Test:
9149*da0073e9SAndroid Build Coastguard Worker            pass
9150*da0073e9SAndroid Build Coastguard Worker
9151*da0073e9SAndroid Build Coastguard Worker        count = [0]
9152*da0073e9SAndroid Build Coastguard Worker
9153*da0073e9SAndroid Build Coastguard Worker        def scope():
9154*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor(1.0, requires_grad=True)
9155*da0073e9SAndroid Build Coastguard Worker            if use_custom_function:
9156*da0073e9SAndroid Build Coastguard Worker                b = Function.apply(a)
9157*da0073e9SAndroid Build Coastguard Worker            else:
9158*da0073e9SAndroid Build Coastguard Worker                b = a.clone()
9159*da0073e9SAndroid Build Coastguard Worker            grad_fn_b = b.grad_fn
9160*da0073e9SAndroid Build Coastguard Worker            obj = Test()
9161*da0073e9SAndroid Build Coastguard Worker
9162*da0073e9SAndroid Build Coastguard Worker            def hook(*args):
9163*da0073e9SAndroid Build Coastguard Worker                # Make sure this hook's closure holds onto grad_fn_b
9164*da0073e9SAndroid Build Coastguard Worker                # This forms a cycle between the hook and grad_fn_b
9165*da0073e9SAndroid Build Coastguard Worker                # We also hold onto a sentinel object 'obj' to track
9166*da0073e9SAndroid Build Coastguard Worker                # whether this cycle is still alive. See 'ref' below.
9167*da0073e9SAndroid Build Coastguard Worker                grad_fn_b
9168*da0073e9SAndroid Build Coastguard Worker                obj
9169*da0073e9SAndroid Build Coastguard Worker                count[0] += 1
9170*da0073e9SAndroid Build Coastguard Worker
9171*da0073e9SAndroid Build Coastguard Worker            if use_tensor_hook:
9172*da0073e9SAndroid Build Coastguard Worker                b.register_hook(hook)
9173*da0073e9SAndroid Build Coastguard Worker            else:
9174*da0073e9SAndroid Build Coastguard Worker                b.grad_fn.register_hook(hook)
9175*da0073e9SAndroid Build Coastguard Worker            c = b.clone()
9176*da0073e9SAndroid Build Coastguard Worker            ref = weakref.ref(obj)
9177*da0073e9SAndroid Build Coastguard Worker            return c, ref
9178*da0073e9SAndroid Build Coastguard Worker
9179*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
9180*da0073e9SAndroid Build Coastguard Worker            out, ref = scope()
9181*da0073e9SAndroid Build Coastguard Worker            out.backward(retain_graph=True)
9182*da0073e9SAndroid Build Coastguard Worker
9183*da0073e9SAndroid Build Coastguard Worker            gc.collect()
9184*da0073e9SAndroid Build Coastguard Worker
9185*da0073e9SAndroid Build Coastguard Worker            # Make sure gc does not clear the cycle noted above.
9186*da0073e9SAndroid Build Coastguard Worker            # i.e. the hook is still alive and fires even after gc runs
9187*da0073e9SAndroid Build Coastguard Worker            out.backward(retain_graph=True)
9188*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(count[0], 2)
9189*da0073e9SAndroid Build Coastguard Worker
9190*da0073e9SAndroid Build Coastguard Worker            # ref is still alive because the use_count of the cpp grad_fn
9191*da0073e9SAndroid Build Coastguard Worker            # shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the
9192*da0073e9SAndroid Build Coastguard Worker            # rest of the graph holds onto the shared_ptr
9193*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(ref())
9194*da0073e9SAndroid Build Coastguard Worker
9195*da0073e9SAndroid Build Coastguard Worker            # Then delete the rest of the graph and check that ref is dead
9196*da0073e9SAndroid Build Coastguard Worker            del out
9197*da0073e9SAndroid Build Coastguard Worker            gc.collect()
9198*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(ref())
9199*da0073e9SAndroid Build Coastguard Worker
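    # A minimal sketch (not part of the original tests) contrasting the two hook
    # registration points exercised above: a hook registered on the Tensor fires
    # with the gradient of that tensor, while a hook registered on its grad_fn
    # Node fires with the Node's (grad_inputs, grad_outputs). The helper name is
    # illustrative only.
    def _sketch_tensor_vs_node_hooks(self):
        a = torch.tensor(1.0, requires_grad=True)
        b = a.clone()
        events = []

        def tensor_hook(grad):
            # receives the gradient flowing into `b`
            events.append(("tensor", grad.item()))

        def node_hook(grad_inputs, grad_outputs):
            # receives the gradients produced by / flowing into b.grad_fn
            events.append(("node", len(grad_inputs), len(grad_outputs)))

        b.register_hook(tensor_hook)
        b.grad_fn.register_hook(node_hook)
        b.sum().backward()
        # both hooks fire exactly once during this backward call
        self.assertEqual(len(events), 2)
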
9200*da0073e9SAndroid Build Coastguard Worker    def test_full_backward_hook_double_backward(self):
9201*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(1, requires_grad=True)
9202*da0073e9SAndroid Build Coastguard Worker        y = torch.rand_like(x)
9203*da0073e9SAndroid Build Coastguard Worker
9204*da0073e9SAndroid Build Coastguard Worker        func = torch.nn.MSELoss()
9205*da0073e9SAndroid Build Coastguard Worker        counter = [0]
9206*da0073e9SAndroid Build Coastguard Worker
9207*da0073e9SAndroid Build Coastguard Worker        def hook(module, grad_input, grad_output):
9208*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
9209*da0073e9SAndroid Build Coastguard Worker
9210*da0073e9SAndroid Build Coastguard Worker        func.register_full_backward_hook(hook)
9211*da0073e9SAndroid Build Coastguard Worker
9212*da0073e9SAndroid Build Coastguard Worker        f = func(x, y)
9213*da0073e9SAndroid Build Coastguard Worker
9214*da0073e9SAndroid Build Coastguard Worker        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
9215*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
9216*da0073e9SAndroid Build Coastguard Worker        _ = torch.autograd.grad(gradx_f, x)
9217*da0073e9SAndroid Build Coastguard Worker        # We should not error, and the counter should not be incremented
9218*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
9219*da0073e9SAndroid Build Coastguard Worker
9220*da0073e9SAndroid Build Coastguard Worker    def test_input_buffer_accum(self):
9221*da0073e9SAndroid Build Coastguard Worker        leaf = torch.rand(2, 2, requires_grad=True)
9222*da0073e9SAndroid Build Coastguard Worker
9223*da0073e9SAndroid Build Coastguard Worker        # An op that returns sparse gradients
9224*da0073e9SAndroid Build Coastguard Worker        ind = torch.tensor([[0, 0]], dtype=torch.long)
9225*da0073e9SAndroid Build Coastguard Worker        out2 = leaf.gather(0, ind, sparse_grad=True)
9226*da0073e9SAndroid Build Coastguard Worker
9227*da0073e9SAndroid Build Coastguard Worker        # An op that returns the gradients as-is
9228*da0073e9SAndroid Build Coastguard Worker        out1 = leaf.clone()
9229*da0073e9SAndroid Build Coastguard Worker
9230*da0073e9SAndroid Build Coastguard Worker        grad_out1_original = torch.rand_like(out1)
9231*da0073e9SAndroid Build Coastguard Worker        grad_out1 = grad_out1_original.clone()
9232*da0073e9SAndroid Build Coastguard Worker        grad_out2 = torch.rand_like(out2)
9233*da0073e9SAndroid Build Coastguard Worker
9234*da0073e9SAndroid Build Coastguard Worker        torch.autograd.backward((out1, out2), (grad_out1, grad_out2))
9235*da0073e9SAndroid Build Coastguard Worker
9236*da0073e9SAndroid Build Coastguard Worker        # Given gradients should not be modified inplace
9237*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad_out1, grad_out1_original)
9238*da0073e9SAndroid Build Coastguard Worker
9239*da0073e9SAndroid Build Coastguard Worker    def test_no_unnecessary_unwrapping(self):
9240*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
9241*da0073e9SAndroid Build Coastguard Worker        a_orig = a.detach().clone()
9242*da0073e9SAndroid Build Coastguard Worker        b = a * a
9243*da0073e9SAndroid Build Coastguard Worker        c = a * b
9244*da0073e9SAndroid Build Coastguard Worker        d = torch.exp(a)
9245*da0073e9SAndroid Build Coastguard Worker
9246*da0073e9SAndroid Build Coastguard Worker        # a is a leaf
9247*da0073e9SAndroid Build Coastguard Worker        self.assertIs(b.grad_fn._saved_self, a)
9248*da0073e9SAndroid Build Coastguard Worker        self.assertIs(b.grad_fn._saved_other, a)
9249*da0073e9SAndroid Build Coastguard Worker        self.assertIs(c.grad_fn._saved_self, a)
9250*da0073e9SAndroid Build Coastguard Worker
9251*da0073e9SAndroid Build Coastguard Worker        # b is not an output
9252*da0073e9SAndroid Build Coastguard Worker        self.assertIs(c.grad_fn._saved_other, b)
9253*da0073e9SAndroid Build Coastguard Worker
9254*da0073e9SAndroid Build Coastguard Worker        # d is an output
9255*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(d.grad_fn._saved_result, d)
9256*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(d.grad_fn._saved_result, d)
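        # (Outputs are not saved as the original Python object; unpacking
        # _saved_result rebuilds a tensor around the saved data, which is why
        # the values match but the identities differ.)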
9257*da0073e9SAndroid Build Coastguard Worker
9258*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
9259*da0073e9SAndroid Build Coastguard Worker
9260*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
9261*da0073e9SAndroid Build Coastguard Worker            c.grad_fn._saved_self
9262*da0073e9SAndroid Build Coastguard Worker
9263*da0073e9SAndroid Build Coastguard Worker        # a is left untouched
9264*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, a_orig)
9265*da0073e9SAndroid Build Coastguard Worker
9266*da0073e9SAndroid Build Coastguard Worker    def test_saved_variable_version_counter(self):
9267*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(2, requires_grad=True)
9268*da0073e9SAndroid Build Coastguard Worker
9269*da0073e9SAndroid Build Coastguard Worker        b = torch.exp(a)
9270*da0073e9SAndroid Build Coastguard Worker
9271*da0073e9SAndroid Build Coastguard Worker        b_unpacked = b.grad_fn._saved_result
9272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b, b_unpacked)
9273*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b._version, b_unpacked._version)
9274*da0073e9SAndroid Build Coastguard Worker
9275*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
9276*da0073e9SAndroid Build Coastguard Worker            b += 1
9277*da0073e9SAndroid Build Coastguard Worker
9278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b, b_unpacked)
9279*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b._version, b_unpacked._version)
9280*da0073e9SAndroid Build Coastguard Worker
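    # A minimal sketch (not part of the original tests) of what the shared
    # version counter enables: mutating a tensor in-place after it has been
    # saved as an input for backward is detected and raises. The helper name is
    # illustrative only.
    def _sketch_version_counter_detects_inplace(self):
        x = torch.rand(2, requires_grad=True)
        y = x * 2
        z = y**2  # PowBackward saves `y` (an input)
        y.add_(1)  # in-place update bumps y's version counter
        with self.assertRaisesRegex(
            RuntimeError, "modified by an inplace operation"
        ):
            z.sum().backward()
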
9281*da0073e9SAndroid Build Coastguard Worker    def test_saved_variable_packing_unpacking_saved_original_with_hooks(self):
9282*da0073e9SAndroid Build Coastguard Worker        # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
9283*da0073e9SAndroid Build Coastguard Worker        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
9284*da0073e9SAndroid Build Coastguard Worker        # attribute of `SavedVariable`.
9285*da0073e9SAndroid Build Coastguard Worker
9286*da0073e9SAndroid Build Coastguard Worker        def test(get_input, is_leaf):
9287*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9288*da0073e9SAndroid Build Coastguard Worker            grad_fn = a.grad_fn
9289*da0073e9SAndroid Build Coastguard Worker            y = a * a
9290*da0073e9SAndroid Build Coastguard Worker            y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x / 2)
9291*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, y.grad_fn._saved_self)
9292*da0073e9SAndroid Build Coastguard Worker            if not is_leaf:
9293*da0073e9SAndroid Build Coastguard Worker                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
9294*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9295*da0073e9SAndroid Build Coastguard Worker            else:
9296*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9297*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(2 * a, a.grad)
9298*da0073e9SAndroid Build Coastguard Worker
9299*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9300*da0073e9SAndroid Build Coastguard Worker            grad_fn = a.grad_fn
9301*da0073e9SAndroid Build Coastguard Worker            y = a * a
9302*da0073e9SAndroid Build Coastguard Worker            y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x)
9303*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(2 * a, y.grad_fn._saved_self)
9304*da0073e9SAndroid Build Coastguard Worker            if not is_leaf:
9305*da0073e9SAndroid Build Coastguard Worker                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
9306*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9307*da0073e9SAndroid Build Coastguard Worker            else:
9308*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9309*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(3 * a, a.grad)
9310*da0073e9SAndroid Build Coastguard Worker
9311*da0073e9SAndroid Build Coastguard Worker            # double backward
9312*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9313*da0073e9SAndroid Build Coastguard Worker            grad_fn = a.grad_fn
9314*da0073e9SAndroid Build Coastguard Worker            y = a**3
9315*da0073e9SAndroid Build Coastguard Worker            y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
9316*da0073e9SAndroid Build Coastguard Worker            s = torch.sum(y)
9317*da0073e9SAndroid Build Coastguard Worker            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9318*da0073e9SAndroid Build Coastguard Worker            if not is_leaf:
9319*da0073e9SAndroid Build Coastguard Worker                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
9320*da0073e9SAndroid Build Coastguard Worker                g.sum().backward()
9321*da0073e9SAndroid Build Coastguard Worker            else:
9322*da0073e9SAndroid Build Coastguard Worker                g.sum().backward()
9323*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(6 * a, a.grad)
9324*da0073e9SAndroid Build Coastguard Worker
9325*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9326*da0073e9SAndroid Build Coastguard Worker            y = a * a
9327*da0073e9SAndroid Build Coastguard Worker            y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: 1)
9328*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
9329*da0073e9SAndroid Build Coastguard Worker                TypeError, "Output of saved tensor unpack_hook expected to be a Tensor"
9330*da0073e9SAndroid Build Coastguard Worker            ):
9331*da0073e9SAndroid Build Coastguard Worker                print(y.grad_fn._saved_self)
9332*da0073e9SAndroid Build Coastguard Worker
9333*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9334*da0073e9SAndroid Build Coastguard Worker            y = a * a
9335*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
9336*da0073e9SAndroid Build Coastguard Worker                TypeError, "missing 1 required positional argument"
9337*da0073e9SAndroid Build Coastguard Worker            ):
9338*da0073e9SAndroid Build Coastguard Worker                y.grad_fn._raw_saved_self.register_hooks(lambda x, b: x, lambda x: x)
9339*da0073e9SAndroid Build Coastguard Worker
9340*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9341*da0073e9SAndroid Build Coastguard Worker            y = a * a
9342*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
9343*da0073e9SAndroid Build Coastguard Worker                TypeError, "missing 1 required positional argument"
9344*da0073e9SAndroid Build Coastguard Worker            ):
9345*da0073e9SAndroid Build Coastguard Worker                y.grad_fn._raw_saved_self.register_hooks(
9346*da0073e9SAndroid Build Coastguard Worker                    lambda x, b: (x, b), lambda x: x
9347*da0073e9SAndroid Build Coastguard Worker                )
9348*da0073e9SAndroid Build Coastguard Worker
9349*da0073e9SAndroid Build Coastguard Worker            def inplace_double(x):
9350*da0073e9SAndroid Build Coastguard Worker                x *= 2
9351*da0073e9SAndroid Build Coastguard Worker                return x
9352*da0073e9SAndroid Build Coastguard Worker
9353*da0073e9SAndroid Build Coastguard Worker            a = get_input()
9354*da0073e9SAndroid Build Coastguard Worker            t = a * a
9355*da0073e9SAndroid Build Coastguard Worker
9356*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
9357*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
9358*da0073e9SAndroid Build Coastguard Worker                "A saved tensor pack hook is modifying its input in place.",
9359*da0073e9SAndroid Build Coastguard Worker            ):
9360*da0073e9SAndroid Build Coastguard Worker                t.grad_fn._raw_saved_self.register_hooks(
9361*da0073e9SAndroid Build Coastguard Worker                    inplace_double, lambda x: x / 2
9362*da0073e9SAndroid Build Coastguard Worker                )
9363*da0073e9SAndroid Build Coastguard Worker
9364*da0073e9SAndroid Build Coastguard Worker        # leaf
9365*da0073e9SAndroid Build Coastguard Worker        test(lambda: torch.randn(5, requires_grad=True), True)
9366*da0073e9SAndroid Build Coastguard Worker
9367*da0073e9SAndroid Build Coastguard Worker        # not leaf, not output
9368*da0073e9SAndroid Build Coastguard Worker        test(lambda: (1 + torch.randn(5, requires_grad=True)), False)
9369*da0073e9SAndroid Build Coastguard Worker
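    # A minimal sketch (not part of the original tests): register_hooks on a
    # _raw_saved_* attribute attaches a pack/unpack pair to that single
    # SavedVariable only, unlike torch.autograd.graph.saved_tensors_hooks which
    # applies to everything saved while the context manager is active. The pack
    # hook runs at registration time (the value is re-packed), and the unpack
    # hook runs when backward uses the saved tensor. The helper name is
    # illustrative only.
    def _sketch_per_saved_variable_hooks(self):
        calls = []

        def pack(x):
            calls.append("pack")
            return x

        def unpack(x):
            calls.append("unpack")
            return x

        a = torch.randn(5, requires_grad=True)
        y = a * a
        y.grad_fn._raw_saved_self.register_hooks(pack, unpack)
        self.assertEqual(calls, ["pack"])
        y.sum().backward()
        self.assertEqual(calls, ["pack", "unpack"])
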
9370*da0073e9SAndroid Build Coastguard Worker    def test_saved_variable_saved_original_inplace_detach(self):
9371*da0073e9SAndroid Build Coastguard Worker        # Detaching a tensor that is saved as input raises
9372*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True).clone()
9373*da0073e9SAndroid Build Coastguard Worker        b = a.sin()
9374*da0073e9SAndroid Build Coastguard Worker        a.detach_()
9375*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
9376*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "Trying to use a saved tensor that has been detached"
9377*da0073e9SAndroid Build Coastguard Worker        ):
9378*da0073e9SAndroid Build Coastguard Worker            b.backward()
9379*da0073e9SAndroid Build Coastguard Worker
9380*da0073e9SAndroid Build Coastguard Worker        # Detaching a tensor that is saved as output is OK
9381*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True).clone()
9382*da0073e9SAndroid Build Coastguard Worker        b = a.exp()
9383*da0073e9SAndroid Build Coastguard Worker        a.detach_()
9384*da0073e9SAndroid Build Coastguard Worker        b.backward()
9385*da0073e9SAndroid Build Coastguard Worker
9386*da0073e9SAndroid Build Coastguard Worker    def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self):
9387*da0073e9SAndroid Build Coastguard Worker        # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
9388*da0073e9SAndroid Build Coastguard Worker        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
9389*da0073e9SAndroid Build Coastguard Worker        # attribute of `SavedVariable`.
9390*da0073e9SAndroid Build Coastguard Worker
9391*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
9392*da0073e9SAndroid Build Coastguard Worker        y = torch.exp(a)
9393*da0073e9SAndroid Build Coastguard Worker        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)
9394*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y.grad_fn._saved_result)
9395*da0073e9SAndroid Build Coastguard Worker        self.assertIs(y.grad_fn, y.grad_fn._saved_result.grad_fn)
9396*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
9397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, y)
9398*da0073e9SAndroid Build Coastguard Worker
9399*da0073e9SAndroid Build Coastguard Worker    def test_saved_variable_packing_unpacking_saved_original_with_default_hooks(self):
9400*da0073e9SAndroid Build Coastguard Worker        # Tests that default hooks are properly registered, used and reset
9401*da0073e9SAndroid Build Coastguard Worker        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
9402*da0073e9SAndroid Build Coastguard Worker        # attribute of `SavedVariable`.
9403*da0073e9SAndroid Build Coastguard Worker        # See also:
9404*da0073e9SAndroid Build Coastguard Worker        #  - test_saved_variable_packing_unpacking_saved_original_with_hooks
9405*da0073e9SAndroid Build Coastguard Worker
9406*da0073e9SAndroid Build Coastguard Worker        def pack(x):
9407*da0073e9SAndroid Build Coastguard Worker            warnings.warn("pack")
9408*da0073e9SAndroid Build Coastguard Worker            return x
9409*da0073e9SAndroid Build Coastguard Worker
9410*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
9411*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(5, requires_grad=True)
9412*da0073e9SAndroid Build Coastguard Worker
9413*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
9414*da0073e9SAndroid Build Coastguard Worker                warnings.simplefilter("always")
9415*da0073e9SAndroid Build Coastguard Worker                y = a * a
9416*da0073e9SAndroid Build Coastguard Worker                # should raise two warnings because `a` is saved twice (as self and other)
9417*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(w), 2)
9418*da0073e9SAndroid Build Coastguard Worker
9419*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9420*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9421*da0073e9SAndroid Build Coastguard Worker            y = a * a
9422*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, y.grad_fn._saved_self)
9423*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, y.grad_fn._saved_other)
9424*da0073e9SAndroid Build Coastguard Worker            y.sum().backward()
9425*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(2 * a, a.grad)
9426*da0073e9SAndroid Build Coastguard Worker
9427*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x / 2):
9428*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9429*da0073e9SAndroid Build Coastguard Worker            y = a * a
9430*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, y.grad_fn._saved_self)
9431*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a, y.grad_fn._saved_other)
9432*da0073e9SAndroid Build Coastguard Worker            y.sum().backward()
9433*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(2 * a, a.grad)
9434*da0073e9SAndroid Build Coastguard Worker
9435*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9436*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9437*da0073e9SAndroid Build Coastguard Worker            y = a * a
9438*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(2 * a, y.grad_fn._saved_self)
9439*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(2 * a, y.grad_fn._saved_other)
9440*da0073e9SAndroid Build Coastguard Worker            y.sum().backward()
9441*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(4 * a, a.grad)
9442*da0073e9SAndroid Build Coastguard Worker
9443*da0073e9SAndroid Build Coastguard Worker        # After exiting the context manager, the default hooks no longer apply
9444*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
9445*da0073e9SAndroid Build Coastguard Worker        y = a * a
9446*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, y.grad_fn._saved_self)
9447*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a, y.grad_fn._saved_other)
9448*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
9449*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2 * a, a.grad)
9450*da0073e9SAndroid Build Coastguard Worker
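    # A minimal sketch (not part of the original tests) of a practical use of
    # the default hooks: keep saved activations in half precision and cast back
    # on unpack. Some numerical error is expected; this only illustrates the API.
    def _sketch_half_precision_saved_tensors(self):
        def pack(x):
            return (x.dtype, x.to(torch.float16))

        def unpack(packed):
            dtype, x = packed
            return x.to(dtype)

        a = torch.randn(5, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            y = (a * a).sum()
        y.backward()
        self.assertEqual(a.grad, 2 * a, atol=1e-2, rtol=1e-2)
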
9451*da0073e9SAndroid Build Coastguard Worker    def test_saved_variable_packing_unpacking_did_not_save_original_with_default_hooks(
9452*da0073e9SAndroid Build Coastguard Worker        self,
9453*da0073e9SAndroid Build Coastguard Worker    ):
9454*da0073e9SAndroid Build Coastguard Worker        # See also test_saved_variable_packing_unpacking_did_not_save_original_with_hooks
9455*da0073e9SAndroid Build Coastguard Worker
9456*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9457*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9458*da0073e9SAndroid Build Coastguard Worker            y = torch.exp(a)
9459*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(y, y.grad_fn._saved_result)
9460*da0073e9SAndroid Build Coastguard Worker            y.sum().backward()
9461*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(a.grad, y)
9462*da0073e9SAndroid Build Coastguard Worker
9463*da0073e9SAndroid Build Coastguard Worker    def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
9464*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9465*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9466*da0073e9SAndroid Build Coastguard Worker                pass
9467*da0073e9SAndroid Build Coastguard Worker
9468*da0073e9SAndroid Build Coastguard Worker    def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
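        # Nested saved_tensors_hooks form a stack: tensors saved while the inner
        # context manager is active use the inner pack/unpack pair (5x here),
        # while tensors saved after it exits fall back to the outer pair (3x).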
9469*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
9470*da0073e9SAndroid Build Coastguard Worker            b = torch.randn(5, requires_grad=True)
9471*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.graph.saved_tensors_hooks(
9472*da0073e9SAndroid Build Coastguard Worker                lambda x: 5 * x, lambda x: 5 * x
9473*da0073e9SAndroid Build Coastguard Worker            ):
9474*da0073e9SAndroid Build Coastguard Worker                a = torch.randn(5, requires_grad=True)
9475*da0073e9SAndroid Build Coastguard Worker                y = a * a
9476*da0073e9SAndroid Build Coastguard Worker            z = b * b
9477*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
9478*da0073e9SAndroid Build Coastguard Worker        z.sum().backward()
9479*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2 * 5 * 5 * a, a.grad)
9480*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(2 * 3 * 3 * b, b.grad)
9481*da0073e9SAndroid Build Coastguard Worker
9482*da0073e9SAndroid Build Coastguard Worker    def test_disabling_saved_tensor_hooks(self):
9483*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
9484*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "error message"):
9485*da0073e9SAndroid Build Coastguard Worker                with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9486*da0073e9SAndroid Build Coastguard Worker                    pass
9487*da0073e9SAndroid Build Coastguard Worker
9488*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
9489*da0073e9SAndroid Build Coastguard Worker
9490*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9491*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "error message"):
9492*da0073e9SAndroid Build Coastguard Worker                with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
9493*da0073e9SAndroid Build Coastguard Worker                    pass
9494*da0073e9SAndroid Build Coastguard Worker
9495*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
9496*da0073e9SAndroid Build Coastguard Worker
9497*da0073e9SAndroid Build Coastguard Worker    def test_disabling_saved_tensor_hooks_nested(self):
9498*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.disable_saved_tensors_hooks("outer"):
9499*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.graph.disable_saved_tensors_hooks("inner"):
9500*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, "inner"):
9501*da0073e9SAndroid Build Coastguard Worker                    with torch.autograd.graph.saved_tensors_hooks(
9502*da0073e9SAndroid Build Coastguard Worker                        lambda x: x, lambda x: x
9503*da0073e9SAndroid Build Coastguard Worker                    ):
9504*da0073e9SAndroid Build Coastguard Worker                        pass
9505*da0073e9SAndroid Build Coastguard Worker
9506*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled())
9507*da0073e9SAndroid Build Coastguard Worker
9508*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())
9509*da0073e9SAndroid Build Coastguard Worker
9510*da0073e9SAndroid Build Coastguard Worker    def test_saved_tensor_hooks_custom_error_propagation(self):
9511*da0073e9SAndroid Build Coastguard Worker        class CustomError(Exception):
9512*da0073e9SAndroid Build Coastguard Worker            pass
9513*da0073e9SAndroid Build Coastguard Worker
9514*da0073e9SAndroid Build Coastguard Worker        class error_on_pack_hook(torch.autograd.graph.saved_tensors_hooks):
9515*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9516*da0073e9SAndroid Build Coastguard Worker                def pack_hook(x):
9517*da0073e9SAndroid Build Coastguard Worker                    raise CustomError("pack")
9518*da0073e9SAndroid Build Coastguard Worker
9519*da0073e9SAndroid Build Coastguard Worker                super().__init__(pack_hook, lambda x: x)
9520*da0073e9SAndroid Build Coastguard Worker
9521*da0073e9SAndroid Build Coastguard Worker        class error_on_unpack_hook(torch.autograd.graph.saved_tensors_hooks):
9522*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9523*da0073e9SAndroid Build Coastguard Worker                def unpack_hook(x):
9524*da0073e9SAndroid Build Coastguard Worker                    raise CustomError("unpack")
9525*da0073e9SAndroid Build Coastguard Worker
9526*da0073e9SAndroid Build Coastguard Worker                super().__init__(lambda x: x, unpack_hook)
9527*da0073e9SAndroid Build Coastguard Worker
9528*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
9529*da0073e9SAndroid Build Coastguard Worker
9530*da0073e9SAndroid Build Coastguard Worker        with error_on_pack_hook():
9531*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(CustomError, "pack"):
9532*da0073e9SAndroid Build Coastguard Worker                out = torch.sin(a)
9533*da0073e9SAndroid Build Coastguard Worker
9534*da0073e9SAndroid Build Coastguard Worker        with error_on_unpack_hook():
9535*da0073e9SAndroid Build Coastguard Worker            out = torch.sin(a)
9536*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(CustomError, "unpack"):
9537*da0073e9SAndroid Build Coastguard Worker                out.backward()
9538*da0073e9SAndroid Build Coastguard Worker
9539*da0073e9SAndroid Build Coastguard Worker    def test_saved_tensor_hooks_custom_function_intermediates(self):
9540*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
9541*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9542*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
9543*da0073e9SAndroid Build Coastguard Worker                intermediate = x.exp()
9544*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(
9545*da0073e9SAndroid Build Coastguard Worker                    intermediate.clone().detach_().requires_grad_(True)
9546*da0073e9SAndroid Build Coastguard Worker                )
9547*da0073e9SAndroid Build Coastguard Worker                return x.exp()
9548*da0073e9SAndroid Build Coastguard Worker
9549*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9550*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_out):
9551*da0073e9SAndroid Build Coastguard Worker                (intermediate,) = ctx.saved_tensors
9552*da0073e9SAndroid Build Coastguard Worker                return grad_out * intermediate
9553*da0073e9SAndroid Build Coastguard Worker
9554*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
9555*da0073e9SAndroid Build Coastguard Worker
9556*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9557*da0073e9SAndroid Build Coastguard Worker            out = Func.apply(a)
9558*da0073e9SAndroid Build Coastguard Worker        out.backward()
9559*da0073e9SAndroid Build Coastguard Worker
9560*da0073e9SAndroid Build Coastguard Worker    def test_unpack_hooks_exec_count(self):
9561*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
9562*da0073e9SAndroid Build Coastguard Worker            return x * y
9563*da0073e9SAndroid Build Coastguard Worker
9564*da0073e9SAndroid Build Coastguard Worker        pack_count = 0
9565*da0073e9SAndroid Build Coastguard Worker        unpack_count = 0
9566*da0073e9SAndroid Build Coastguard Worker
9567*da0073e9SAndroid Build Coastguard Worker        def pack_hook(x):
9568*da0073e9SAndroid Build Coastguard Worker            nonlocal pack_count
9569*da0073e9SAndroid Build Coastguard Worker            pack_count += 1
9570*da0073e9SAndroid Build Coastguard Worker            return x
9571*da0073e9SAndroid Build Coastguard Worker
9572*da0073e9SAndroid Build Coastguard Worker        # the unpack hook shouldn't run while the forward is traced/recorded, only during backward
9573*da0073e9SAndroid Build Coastguard Worker        def unpack_hook(x):
9574*da0073e9SAndroid Build Coastguard Worker            nonlocal unpack_count
9575*da0073e9SAndroid Build Coastguard Worker            unpack_count += 1
9576*da0073e9SAndroid Build Coastguard Worker            return x
9577*da0073e9SAndroid Build Coastguard Worker
9578*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(4, requires_grad=True)
9579*da0073e9SAndroid Build Coastguard Worker        y = torch.ones(4, requires_grad=False)
9580*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
9581*da0073e9SAndroid Build Coastguard Worker            out_test = f(x, y)
9582*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pack_count, 1)
9583*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unpack_count, 0)
9584*da0073e9SAndroid Build Coastguard Worker            out_test.sum().backward()
9585*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(pack_count, 1)
9586*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(unpack_count, 1)
9587*da0073e9SAndroid Build Coastguard Worker
9588*da0073e9SAndroid Build Coastguard Worker    def test_saved_tensors_hook_version_counter_not_shared(self):
9589*da0073e9SAndroid Build Coastguard Worker        class Test(torch.autograd.Function):
9590*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9591*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
9592*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
9593*da0073e9SAndroid Build Coastguard Worker                return x.sin()
9594*da0073e9SAndroid Build Coastguard Worker
9595*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9596*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
9597*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
9598*da0073e9SAndroid Build Coastguard Worker                before = a._version
9599*da0073e9SAndroid Build Coastguard Worker                x.add_(1)
9600*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a._version, before)
9601*da0073e9SAndroid Build Coastguard Worker                return grad_output
9602*da0073e9SAndroid Build Coastguard Worker
9603*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
9604*da0073e9SAndroid Build Coastguard Worker        a_replacement = a.clone()
9605*da0073e9SAndroid Build Coastguard Worker
9606*da0073e9SAndroid Build Coastguard Worker        def pack_hook(x):
9607*da0073e9SAndroid Build Coastguard Worker            return a_replacement
9608*da0073e9SAndroid Build Coastguard Worker
9609*da0073e9SAndroid Build Coastguard Worker        def unpack_hook(x):
9610*da0073e9SAndroid Build Coastguard Worker            return x
9611*da0073e9SAndroid Build Coastguard Worker
9612*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
9613*da0073e9SAndroid Build Coastguard Worker            b = Test.apply(a)
9614*da0073e9SAndroid Build Coastguard Worker
9615*da0073e9SAndroid Build Coastguard Worker        b.backward()
9616*da0073e9SAndroid Build Coastguard Worker
9617*da0073e9SAndroid Build Coastguard Worker    def test_save_on_cpu_and_checkpoint(self):
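        # Compare gradients from three equivalent computations of a**16:
        # a plain baseline, save_on_cpu wrapping a checkpointed segment, and a
        # checkpointed function that uses save_on_cpu internally.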
9618*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, requires_grad=True)
9619*da0073e9SAndroid Build Coastguard Worker
9620*da0073e9SAndroid Build Coastguard Worker        b = a.pow(2).pow(2).pow(2).pow(2)
9621*da0073e9SAndroid Build Coastguard Worker        b.sum().backward()
9622*da0073e9SAndroid Build Coastguard Worker        b_grad = a.grad.clone()
9623*da0073e9SAndroid Build Coastguard Worker        a.grad.zero_()
9624*da0073e9SAndroid Build Coastguard Worker
9625*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.save_on_cpu():
9626*da0073e9SAndroid Build Coastguard Worker            h = a.pow(2)
9627*da0073e9SAndroid Build Coastguard Worker            h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
9628*da0073e9SAndroid Build Coastguard Worker            c = h.pow(2)
9629*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
9630*da0073e9SAndroid Build Coastguard Worker        c_grad = a.grad.clone()
9631*da0073e9SAndroid Build Coastguard Worker        a.grad.zero_()
9632*da0073e9SAndroid Build Coastguard Worker
9633*da0073e9SAndroid Build Coastguard Worker        def f(a):
9634*da0073e9SAndroid Build Coastguard Worker            h = a.pow(2)
9635*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.graph.save_on_cpu():
9636*da0073e9SAndroid Build Coastguard Worker                h = h.pow(2).pow(2)
9637*da0073e9SAndroid Build Coastguard Worker            return h.pow(2)
9638*da0073e9SAndroid Build Coastguard Worker
9639*da0073e9SAndroid Build Coastguard Worker        d = checkpoint(f, a, use_reentrant=False)
9640*da0073e9SAndroid Build Coastguard Worker        d.sum().backward()
9641*da0073e9SAndroid Build Coastguard Worker        d_grad = a.grad.clone()
9642*da0073e9SAndroid Build Coastguard Worker
9643*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_grad, c_grad)
9644*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b_grad, d_grad)
9645*da0073e9SAndroid Build Coastguard Worker
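    # A minimal sketch (not part of the original tests) of save_on_cpu on its
    # own: tensors saved for backward are offloaded to CPU during the forward
    # and copied back to their original device when backward needs them. With
    # CPU inputs, as here, the round-trip is a no-op apart from the packing.
    def _sketch_save_on_cpu_basic(self):
        a = torch.randn(5, requires_grad=True)
        with torch.autograd.graph.save_on_cpu(pin_memory=False):
            y = (a * a).sum()
        y.backward()
        self.assertEqual(a.grad, 2 * a)
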
9646*da0073e9SAndroid Build Coastguard Worker    def test_pack_hook_with_inplace_modification_should_fail(self):
9647*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
9648*da0073e9SAndroid Build Coastguard Worker
9649*da0073e9SAndroid Build Coastguard Worker        def inc(x):
9650*da0073e9SAndroid Build Coastguard Worker            x += 1
9651*da0073e9SAndroid Build Coastguard Worker            return x
9652*da0073e9SAndroid Build Coastguard Worker
9653*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(inc, lambda x: x):
9654*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
9655*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
9656*da0073e9SAndroid Build Coastguard Worker                "A saved tensor pack hook is modifying its input in place.",
9657*da0073e9SAndroid Build Coastguard Worker            ):
9658*da0073e9SAndroid Build Coastguard Worker                y = torch.exp(a)
9659*da0073e9SAndroid Build Coastguard Worker
9660*da0073e9SAndroid Build Coastguard Worker        y = torch.exp(a)
9661*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
9662*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "A saved tensor pack hook is modifying its input in place."
9663*da0073e9SAndroid Build Coastguard Worker        ):
9664*da0073e9SAndroid Build Coastguard Worker            y.grad_fn._raw_saved_result.register_hooks(inc, lambda x: x)
9665*da0073e9SAndroid Build Coastguard Worker
9666*da0073e9SAndroid Build Coastguard Worker    def test_saving_variable_to_disk(self):
9667*da0073e9SAndroid Build Coastguard Worker        with tempfile.TemporaryDirectory() as tmp_dir:
9668*da0073e9SAndroid Build Coastguard Worker
9669*da0073e9SAndroid Build Coastguard Worker            def pack(x):
9670*da0073e9SAndroid Build Coastguard Worker                name = os.path.join(tmp_dir, str(uuid.uuid4()))
9671*da0073e9SAndroid Build Coastguard Worker                torch.save(x, name)
9672*da0073e9SAndroid Build Coastguard Worker                return name
9673*da0073e9SAndroid Build Coastguard Worker
9674*da0073e9SAndroid Build Coastguard Worker            def unpack(name):
9675*da0073e9SAndroid Build Coastguard Worker                return torch.load(name)
9676*da0073e9SAndroid Build Coastguard Worker
9677*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
9678*da0073e9SAndroid Build Coastguard Worker                a = torch.ones(5, requires_grad=True)
9679*da0073e9SAndroid Build Coastguard Worker                y = a * a
9680*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, y.grad_fn._saved_self)
9681*da0073e9SAndroid Build Coastguard Worker
9682*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9683*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(2 * a, a.grad)
9684*da0073e9SAndroid Build Coastguard Worker
9685*da0073e9SAndroid Build Coastguard Worker    def test_default_saved_tensors_hooks_double_backward(self):
9686*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
9687*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9688*da0073e9SAndroid Build Coastguard Worker            y = a**3
9689*da0073e9SAndroid Build Coastguard Worker            s = torch.sum(y)
9690*da0073e9SAndroid Build Coastguard Worker            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9691*da0073e9SAndroid Build Coastguard Worker            g.sum().backward()
9692*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(6 * a, a.grad)
9693*da0073e9SAndroid Build Coastguard Worker
9694*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9695*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9696*da0073e9SAndroid Build Coastguard Worker            y = a**3
9697*da0073e9SAndroid Build Coastguard Worker            s = torch.sum(y)
9698*da0073e9SAndroid Build Coastguard Worker        (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9699*da0073e9SAndroid Build Coastguard Worker        g.sum().backward()
9700*da0073e9SAndroid Build Coastguard Worker        # factor 2 because only a is saved once
9701*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(6 * 2 * a, a.grad)
9702*da0073e9SAndroid Build Coastguard Worker
9703*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, requires_grad=True)
9704*da0073e9SAndroid Build Coastguard Worker        y = a**3
9705*da0073e9SAndroid Build Coastguard Worker        s = torch.sum(y)
9706*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9707*da0073e9SAndroid Build Coastguard Worker            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9708*da0073e9SAndroid Build Coastguard Worker            g.sum().backward()
9709*da0073e9SAndroid Build Coastguard Worker            # factor 4 because pow_backward is grad * (exp * self.pow(exp - 1))
9710*da0073e9SAndroid Build Coastguard Worker            # so grad is saved and self (i.e. a) is saved
9711*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(6 * 4 * a, a.grad)
9712*da0073e9SAndroid Build Coastguard Worker
9713*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
9714*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9715*da0073e9SAndroid Build Coastguard Worker            y = a**3
9716*da0073e9SAndroid Build Coastguard Worker            s = torch.sum(y)
9717*da0073e9SAndroid Build Coastguard Worker            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
9718*da0073e9SAndroid Build Coastguard Worker            g.sum().backward()
9719*da0073e9SAndroid Build Coastguard Worker            # combining the two above blocks: 2 * 4 = 8
9720*da0073e9SAndroid Build Coastguard Worker            # note that in that sense, a is saved twice
9721*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(6 * 8 * a, a.grad)
9722*da0073e9SAndroid Build Coastguard Worker
9723*da0073e9SAndroid Build Coastguard Worker    def test_wrapped_number_saved_tensors_hooks(self):
9724*da0073e9SAndroid Build Coastguard Worker        def err_hook(x):
9725*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError("this hook should not be called")
9726*da0073e9SAndroid Build Coastguard Worker
9727*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(err_hook, err_hook):
9728*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(5, requires_grad=True)
9729*da0073e9SAndroid Build Coastguard Worker            out = (a * 3).sum()
9730*da0073e9SAndroid Build Coastguard Worker            # The scalar 3 is saved for backward as a wrapped number, but
9731*da0073e9SAndroid Build Coastguard Worker            # wrapped numbers are special-cased so they do not trigger saved variable hooks
9732*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(out, (a,))
9733*da0073e9SAndroid Build Coastguard Worker
9734*da0073e9SAndroid Build Coastguard Worker    def test_graph_save_on_cpu(self):
9735*da0073e9SAndroid Build Coastguard Worker        def test(get_input, cuda, pin_memory):
9736*da0073e9SAndroid Build Coastguard Worker            with torch.autograd.graph.save_on_cpu(pin_memory):
9737*da0073e9SAndroid Build Coastguard Worker                a = get_input()
9738*da0073e9SAndroid Build Coastguard Worker                if cuda:
9739*da0073e9SAndroid Build Coastguard Worker                    a.cuda()
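                    # NOTE: Tensor.cuda() is out-of-place; the returned copy is
                    # discarded here, so `a` itself stays on its original device.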
9740*da0073e9SAndroid Build Coastguard Worker                y = a * a
9741*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, y.grad_fn._saved_self)
9742*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a, y.grad_fn._saved_other)
9743*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.dtype, y.grad_fn._saved_self.dtype)
9744*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.layout, y.grad_fn._saved_self.layout)
9745*da0073e9SAndroid Build Coastguard Worker                if y.is_sparse:
9746*da0073e9SAndroid Build Coastguard Worker                    y = y.to_dense()
9747*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
9748*da0073e9SAndroid Build Coastguard Worker
9749*da0073e9SAndroid Build Coastguard Worker                expected = 2 * a
9750*da0073e9SAndroid Build Coastguard Worker                actual = a.grad
9751*da0073e9SAndroid Build Coastguard Worker                if a.is_sparse:
9752*da0073e9SAndroid Build Coastguard Worker                    actual = actual.coalesce()
9753*da0073e9SAndroid Build Coastguard Worker                    expected = expected.coalesce()
9754*da0073e9SAndroid Build Coastguard Worker
9755*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(actual, expected)
9756*da0073e9SAndroid Build Coastguard Worker
9757*da0073e9SAndroid Build Coastguard Worker        for cuda in [False] + ([True] if torch.cuda.is_available() else []):
9758*da0073e9SAndroid Build Coastguard Worker            for pin_memory in [True, False]:
9759*da0073e9SAndroid Build Coastguard Worker                # FloatTensor
9760*da0073e9SAndroid Build Coastguard Worker                test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
9761*da0073e9SAndroid Build Coastguard Worker                # DoubleTensor
9762*da0073e9SAndroid Build Coastguard Worker                test(
9763*da0073e9SAndroid Build Coastguard Worker                    lambda: torch.randn(5, requires_grad=True, dtype=torch.double),
9764*da0073e9SAndroid Build Coastguard Worker                    cuda,
9765*da0073e9SAndroid Build Coastguard Worker                    pin_memory,
9766*da0073e9SAndroid Build Coastguard Worker                )
9767*da0073e9SAndroid Build Coastguard Worker                # Sparse tensor
9768*da0073e9SAndroid Build Coastguard Worker                x = torch.sparse_coo_tensor(
9769*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([[1, 1]]).long(),
9770*da0073e9SAndroid Build Coastguard Worker                    torch.tensor([1.0, 1.0]),
9771*da0073e9SAndroid Build Coastguard Worker                    requires_grad=True,
9772*da0073e9SAndroid Build Coastguard Worker                )
9773*da0073e9SAndroid Build Coastguard Worker                test(lambda: x, cuda, pin_memory)
9774*da0073e9SAndroid Build Coastguard Worker
9775*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
9776*da0073e9SAndroid Build Coastguard Worker    def test_graph_save_on_cpu_cuda(self):
9777*da0073e9SAndroid Build Coastguard Worker        def f(x):
9778*da0073e9SAndroid Build Coastguard Worker            a = x + 1
9779*da0073e9SAndroid Build Coastguard Worker            return a * a
9780*da0073e9SAndroid Build Coastguard Worker
9781*da0073e9SAndroid Build Coastguard Worker        # with grad
9782*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, requires_grad=True, device="cuda")
9783*da0073e9SAndroid Build Coastguard Worker        y = f(a)
9784*da0073e9SAndroid Build Coastguard Worker        memory_with_grad = torch.cuda.memory_allocated()
9785*da0073e9SAndroid Build Coastguard Worker
9786*da0073e9SAndroid Build Coastguard Worker        del a
9787*da0073e9SAndroid Build Coastguard Worker        del y
9788*da0073e9SAndroid Build Coastguard Worker
9789*da0073e9SAndroid Build Coastguard Worker        # without grad
9790*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, requires_grad=True, device="cuda")
9791*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
9792*da0073e9SAndroid Build Coastguard Worker            y = f(a)
9793*da0073e9SAndroid Build Coastguard Worker        memory_without_grad = torch.cuda.memory_allocated()
9794*da0073e9SAndroid Build Coastguard Worker
9795*da0073e9SAndroid Build Coastguard Worker        self.assertGreater(memory_with_grad, memory_without_grad)
9796*da0073e9SAndroid Build Coastguard Worker
9797*da0073e9SAndroid Build Coastguard Worker        del a
9798*da0073e9SAndroid Build Coastguard Worker        del y
9799*da0073e9SAndroid Build Coastguard Worker
9800*da0073e9SAndroid Build Coastguard Worker        # with hooks
9801*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.save_on_cpu():
9802*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1, requires_grad=True, device="cuda")
9803*da0073e9SAndroid Build Coastguard Worker            y = f(a)
9804*da0073e9SAndroid Build Coastguard Worker            memory_with_hooks = torch.cuda.memory_allocated()
9805*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(memory_with_hooks, memory_without_grad)
9806*da0073e9SAndroid Build Coastguard Worker
9807*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
9808*da0073e9SAndroid Build Coastguard Worker    def test_scalar_grad_mixed_device(self):
9809*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(1.0, requires_grad=True)
9810*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(2, 2, device="cuda")
9811*da0073e9SAndroid Build Coastguard Worker        out = x * y
9812*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
9813*da0073e9SAndroid Build Coastguard Worker
9814*da0073e9SAndroid Build Coastguard Worker    def test_multi_grad_all_hooks(self):
9815*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand(2, requires_grad=True)
9816*da0073e9SAndroid Build Coastguard Worker        t2 = torch.rand(2, requires_grad=True)
9817*da0073e9SAndroid Build Coastguard Worker        t3 = torch.rand(2, requires_grad=True)
9818*da0073e9SAndroid Build Coastguard Worker        t4 = torch.rand(2, requires_grad=True)
9819*da0073e9SAndroid Build Coastguard Worker
9820*da0073e9SAndroid Build Coastguard Worker        # Ensure we properly detect all types of Nodes here
9821*da0073e9SAndroid Build Coastguard Worker        # C++ Node
9822*da0073e9SAndroid Build Coastguard Worker        t1 = t1.mul(2)
9823*da0073e9SAndroid Build Coastguard Worker
9824*da0073e9SAndroid Build Coastguard Worker        # Python custom Function
9825*da0073e9SAndroid Build Coastguard Worker        class Foo(Function):
9826*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9827*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, a):
9828*da0073e9SAndroid Build Coastguard Worker                return a.clone()
9829*da0073e9SAndroid Build Coastguard Worker
9830*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9831*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
9832*da0073e9SAndroid Build Coastguard Worker                return gO
9833*da0073e9SAndroid Build Coastguard Worker
9834*da0073e9SAndroid Build Coastguard Worker        t2 = Foo.apply(t2)
9835*da0073e9SAndroid Build Coastguard Worker
9836*da0073e9SAndroid Build Coastguard Worker        # C++ Node
9837*da0073e9SAndroid Build Coastguard Worker        t3 = torch._C._functions.UndefinedGrad()(t3)
9838*da0073e9SAndroid Build Coastguard Worker
9839*da0073e9SAndroid Build Coastguard Worker        # C++ Custom Op
9840*da0073e9SAndroid Build Coastguard Worker        cpp_source = """
9841*da0073e9SAndroid Build Coastguard Workerstruct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
9842*da0073e9SAndroid Build Coastguard Worker  static torch::Tensor forward(
9843*da0073e9SAndroid Build Coastguard Worker      torch::autograd::AutogradContext* ctx,
9844*da0073e9SAndroid Build Coastguard Worker      const torch::Tensor& x) {
9845*da0073e9SAndroid Build Coastguard Worker    return x.clone();
9846*da0073e9SAndroid Build Coastguard Worker  }
9847*da0073e9SAndroid Build Coastguard Worker
9848*da0073e9SAndroid Build Coastguard Worker  static torch::autograd::variable_list backward(
9849*da0073e9SAndroid Build Coastguard Worker      torch::autograd::AutogradContext *ctx,
9850*da0073e9SAndroid Build Coastguard Worker      torch::autograd::variable_list grad_output) {
9851*da0073e9SAndroid Build Coastguard Worker    return grad_output;
9852*da0073e9SAndroid Build Coastguard Worker  }
9853*da0073e9SAndroid Build Coastguard Worker};
9854*da0073e9SAndroid Build Coastguard Worker
9855*da0073e9SAndroid Build Coastguard Workertorch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
9856*da0073e9SAndroid Build Coastguard Worker  return CustomOpAutogradFunction::apply(x);
9857*da0073e9SAndroid Build Coastguard Worker}
9858*da0073e9SAndroid Build Coastguard Worker
9859*da0073e9SAndroid Build Coastguard WorkerTORCH_LIBRARY(test_autograd_cpp_node, m) {
9860*da0073e9SAndroid Build Coastguard Worker    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
9861*da0073e9SAndroid Build Coastguard Worker}
9862*da0073e9SAndroid Build Coastguard Worker        """
9863*da0073e9SAndroid Build Coastguard Worker
9864*da0073e9SAndroid Build Coastguard Worker        module = load_inline(
9865*da0073e9SAndroid Build Coastguard Worker            name="test_autograd_cpp_node",
9866*da0073e9SAndroid Build Coastguard Worker            cpp_sources=cpp_source,
9867*da0073e9SAndroid Build Coastguard Worker            functions="custom_op_backed_by_autograd_fn",
9868*da0073e9SAndroid Build Coastguard Worker            verbose=True,
9869*da0073e9SAndroid Build Coastguard Worker        )
9870*da0073e9SAndroid Build Coastguard Worker
9871*da0073e9SAndroid Build Coastguard Worker        t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)
9872*da0073e9SAndroid Build Coastguard Worker
9873*da0073e9SAndroid Build Coastguard Worker        res = [None] * 4
9874*da0073e9SAndroid Build Coastguard Worker        count = [0]
9875*da0073e9SAndroid Build Coastguard Worker
9876*da0073e9SAndroid Build Coastguard Worker        def hook(grads):
9877*da0073e9SAndroid Build Coastguard Worker            nonlocal res
9878*da0073e9SAndroid Build Coastguard Worker            count[0] += 1
9879*da0073e9SAndroid Build Coastguard Worker            res = [g is not None for g in grads]
9880*da0073e9SAndroid Build Coastguard Worker
9881*da0073e9SAndroid Build Coastguard Worker        handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
9882*da0073e9SAndroid Build Coastguard Worker
9883*da0073e9SAndroid Build Coastguard Worker        out = t2 * t3
9884*da0073e9SAndroid Build Coastguard Worker
9885*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(inputs=(t2, t3), retain_graph=True)
9886*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 1)
9887*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, [False, True, True, False])
9888*da0073e9SAndroid Build Coastguard Worker
9889*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(inputs=(t1, t4), retain_graph=True)
9890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 1)
9891*da0073e9SAndroid Build Coastguard Worker
9892*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(inputs=(t1, t3), retain_graph=True)
9893*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 2)
9894*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, [False, False, True, False])
9895*da0073e9SAndroid Build Coastguard Worker
9896*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
9897*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9898*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
9899*da0073e9SAndroid Build Coastguard Worker                return x
9900*da0073e9SAndroid Build Coastguard Worker
9901*da0073e9SAndroid Build Coastguard Worker            @staticmethod
9902*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
9903*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("error message")
9904*da0073e9SAndroid Build Coastguard Worker
9905*da0073e9SAndroid Build Coastguard Worker        out = Func.apply(t2) * t3
9906*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "error message"):
9907*da0073e9SAndroid Build Coastguard Worker            out.sum().backward(inputs=(t2, t3), retain_graph=True)
9908*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 2)
9909*da0073e9SAndroid Build Coastguard Worker
9910*da0073e9SAndroid Build Coastguard Worker        handle.remove()
9911*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(inputs=(t1, t3), retain_graph=True)
9912*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 2)
9913*da0073e9SAndroid Build Coastguard Worker
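    # A minimal usage sketch (a non-test helper, not collected by the suite) of
    # torch.autograd.graph.register_multi_grad_hook in its default "all" mode:
    # the hook fires once per backward call, after the gradients of every
    # watched tensor that participates in that call have been computed.
    def _sketch_multi_grad_all_hook(self):
        a = torch.rand(2, requires_grad=True)
        b = torch.rand(2, requires_grad=True)
        seen = []

        def hook(grads):
            # grads is aligned with (a, b); an entry is None when that
            # tensor's gradient was not computed in this backward call
            seen.append([g is not None for g in grads])

        handle = torch.autograd.graph.register_multi_grad_hook((a, b), hook)
        (a * b).sum().backward()
        assert seen == [[True, True]]
        handle.remove()
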
9914*da0073e9SAndroid Build Coastguard Worker    def test_multi_grad_any_hooks(self):
9915*da0073e9SAndroid Build Coastguard Worker        hook_id = 0
9916*da0073e9SAndroid Build Coastguard Worker        any_hook_handles: List[RemovableHandle] = []
9917*da0073e9SAndroid Build Coastguard Worker
9918*da0073e9SAndroid Build Coastguard Worker        class MultiOutputModule(nn.Module):
9919*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9920*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9921*da0073e9SAndroid Build Coastguard Worker                self.lin = nn.Linear(3, 3)
9922*da0073e9SAndroid Build Coastguard Worker
9923*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
9924*da0073e9SAndroid Build Coastguard Worker                z = self.lin(x)
9925*da0073e9SAndroid Build Coastguard Worker                out = torch.sin(z), torch.cos(z)
9926*da0073e9SAndroid Build Coastguard Worker                nonlocal hook_id
9927*da0073e9SAndroid Build Coastguard Worker                z.register_hook(partial(hook, hook_id))
9928*da0073e9SAndroid Build Coastguard Worker                hook_id += 1
9929*da0073e9SAndroid Build Coastguard Worker                any_hook_handles.append(
9930*da0073e9SAndroid Build Coastguard Worker                    torch.autograd.graph.register_multi_grad_hook(
9931*da0073e9SAndroid Build Coastguard Worker                        out, partial(hook, hook_id), mode="any"
9932*da0073e9SAndroid Build Coastguard Worker                    )
9933*da0073e9SAndroid Build Coastguard Worker                )
9934*da0073e9SAndroid Build Coastguard Worker                hook_id += 1
9935*da0073e9SAndroid Build Coastguard Worker                return out
9936*da0073e9SAndroid Build Coastguard Worker
9937*da0073e9SAndroid Build Coastguard Worker        class Model(nn.Module):
9938*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
9939*da0073e9SAndroid Build Coastguard Worker                super().__init__()
9940*da0073e9SAndroid Build Coastguard Worker                self.mod1 = MultiOutputModule()
9941*da0073e9SAndroid Build Coastguard Worker                self.mod2 = MultiOutputModule()
9942*da0073e9SAndroid Build Coastguard Worker
9943*da0073e9SAndroid Build Coastguard Worker            def forward(self, x: torch.Tensor) -> torch.Tensor:
9944*da0073e9SAndroid Build Coastguard Worker                y = self.mod1(x)
9945*da0073e9SAndroid Build Coastguard Worker                z = y[0] + y[1]
9946*da0073e9SAndroid Build Coastguard Worker                return self.mod2(z)
9947*da0073e9SAndroid Build Coastguard Worker
9948*da0073e9SAndroid Build Coastguard Worker        hook_order: List[int] = []
9949*da0073e9SAndroid Build Coastguard Worker        hook_count = 0
9950*da0073e9SAndroid Build Coastguard Worker
9951*da0073e9SAndroid Build Coastguard Worker        def hook(hook_id: int, *unused):
9952*da0073e9SAndroid Build Coastguard Worker            nonlocal hook_count
9953*da0073e9SAndroid Build Coastguard Worker            nonlocal hook_order
9954*da0073e9SAndroid Build Coastguard Worker            hook_count += 1
9955*da0073e9SAndroid Build Coastguard Worker            hook_order.append(hook_id)
9956*da0073e9SAndroid Build Coastguard Worker
9957*da0073e9SAndroid Build Coastguard Worker        # Any hooks: IDs 1 and 3; regular hooks: IDs 0 and 2
9958*da0073e9SAndroid Build Coastguard Worker        model = Model()
9959*da0073e9SAndroid Build Coastguard Worker        inp = torch.randn((2, 3))
9960*da0073e9SAndroid Build Coastguard Worker        out = model(inp)
9961*da0073e9SAndroid Build Coastguard Worker        (out[0] + out[1]).sum().backward()
9962*da0073e9SAndroid Build Coastguard Worker        # Check that the any-hook runs only once and before the regular hook
9963*da0073e9SAndroid Build Coastguard Worker        # for each module
9964*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(any_hook_handles), 2)
9965*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hook_order, [3, 2, 1, 0])
9966*da0073e9SAndroid Build Coastguard Worker
9967*da0073e9SAndroid Build Coastguard Worker        hook_id = 0
9968*da0073e9SAndroid Build Coastguard Worker        hook_order.clear()
9969*da0073e9SAndroid Build Coastguard Worker        any_hook_handles.clear()
9970*da0073e9SAndroid Build Coastguard Worker        out = model(inp)
9971*da0073e9SAndroid Build Coastguard Worker        for handle in any_hook_handles:
9972*da0073e9SAndroid Build Coastguard Worker            handle.remove()
9973*da0073e9SAndroid Build Coastguard Worker        (out[0] + out[1]).sum().backward()
9974*da0073e9SAndroid Build Coastguard Worker        # Check that the any-hook does not run if removed
9975*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hook_order, [2, 0])
9976*da0073e9SAndroid Build Coastguard Worker
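    # A minimal sketch (a non-test helper, not collected by the suite) of
    # mode="any": the hook fires as soon as the gradient with respect to any
    # one of the watched tensors has been computed, rather than waiting for
    # all of them.
    def _sketch_multi_grad_any_hook(self):
        a = torch.rand(2, requires_grad=True)
        b = torch.rand(2, requires_grad=True)
        calls = []

        def hook(*grad):
            # called once per backward call, with the first gradient that
            # becomes available
            calls.append(grad)

        handle = torch.autograd.graph.register_multi_grad_hook(
            (a, b), hook, mode="any"
        )
        (a * b).sum().backward()
        assert len(calls) == 1
        handle.remove()
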
9977*da0073e9SAndroid Build Coastguard Worker    def test_multi_grad_hooks_invalid_mode(self):
9978*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand(2, requires_grad=True)
9979*da0073e9SAndroid Build Coastguard Worker        t2 = torch.rand(2, requires_grad=True)
9980*da0073e9SAndroid Build Coastguard Worker        regex = r"Expects mode to be one of \('all', 'any'\) but got foo"
9981*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(ValueError, regex):
9982*da0073e9SAndroid Build Coastguard Worker            torch.autograd.graph.register_multi_grad_hook(
9983*da0073e9SAndroid Build Coastguard Worker                (t1, t2), lambda _: None, mode="foo"
9984*da0073e9SAndroid Build Coastguard Worker            )
9985*da0073e9SAndroid Build Coastguard Worker
9986*da0073e9SAndroid Build Coastguard Worker    def test_pynode_destruction_deadlock(self):
9987*da0073e9SAndroid Build Coastguard Worker        script = """
9988*da0073e9SAndroid Build Coastguard Workerimport torch
9989*da0073e9SAndroid Build Coastguard Worker
9990*da0073e9SAndroid Build Coastguard Workerclass Foo(torch.autograd.Function):
9991*da0073e9SAndroid Build Coastguard Worker    @staticmethod
9992*da0073e9SAndroid Build Coastguard Worker    def forward(ctx, x):
9993*da0073e9SAndroid Build Coastguard Worker        return x.clone()
9994*da0073e9SAndroid Build Coastguard Worker
9995*da0073e9SAndroid Build Coastguard Worker    @staticmethod
9996*da0073e9SAndroid Build Coastguard Worker    def backward(ctx, gO):
9997*da0073e9SAndroid Build Coastguard Worker        return gO.clone()
9998*da0073e9SAndroid Build Coastguard Worker
9999*da0073e9SAndroid Build Coastguard Workerdef get_out():
10000*da0073e9SAndroid Build Coastguard Worker    inp = torch.rand(2, requires_grad=True)
10001*da0073e9SAndroid Build Coastguard Worker
10002*da0073e9SAndroid Build Coastguard Worker    # The Python Function is applied first so that its node runs
10003*da0073e9SAndroid Build Coastguard Worker    # last in the backward pass
10004*da0073e9SAndroid Build Coastguard Worker    right = Foo.apply(inp)
10005*da0073e9SAndroid Build Coastguard Worker
10006*da0073e9SAndroid Build Coastguard Worker    # An op that creates new memory
10007*da0073e9SAndroid Build Coastguard Worker    left1 = inp.clone()
10008*da0073e9SAndroid Build Coastguard Worker    # An op that saves its input
10009*da0073e9SAndroid Build Coastguard Worker    left2 = left1 ** 2
10010*da0073e9SAndroid Build Coastguard Worker
10011*da0073e9SAndroid Build Coastguard Worker    # Inplace modify so that the backward for
10012*da0073e9SAndroid Build Coastguard Worker    # left2 always raises an error
10013*da0073e9SAndroid Build Coastguard Worker    left1 += 1
10014*da0073e9SAndroid Build Coastguard Worker
10015*da0073e9SAndroid Build Coastguard Worker    # An op that takes both sides as input.
10016*da0073e9SAndroid Build Coastguard Worker    # After running, the last op of each side will be in
10017*da0073e9SAndroid Build Coastguard Worker    # the ready queue,
10018*da0073e9SAndroid Build Coastguard Worker    # and the op for the left side will run first since it was
10019*da0073e9SAndroid Build Coastguard Worker    # executed last during the forward pass
10020*da0073e9SAndroid Build Coastguard Worker    out = left2 + right
10021*da0073e9SAndroid Build Coastguard Worker
10022*da0073e9SAndroid Build Coastguard Worker    return out
10023*da0073e9SAndroid Build Coastguard Worker
10024*da0073e9SAndroid Build Coastguard Worker# Nothing here should be a global variable because, from what
10025*da0073e9SAndroid Build Coastguard Worker# I can see, Python leaks all global objects
10026*da0073e9SAndroid Build Coastguard Workerget_out().sum().backward()
10027*da0073e9SAndroid Build Coastguard Worker
10028*da0073e9SAndroid Build Coastguard Worker# This used to deadlock when the PyNode is being destroyed after
10029*da0073e9SAndroid Build Coastguard Worker# the error is raised.
10030*da0073e9SAndroid Build Coastguard Worker"""
10031*da0073e9SAndroid Build Coastguard Worker        try:
10032*da0073e9SAndroid Build Coastguard Worker            subprocess.check_output(
10033*da0073e9SAndroid Build Coastguard Worker                [sys.executable, "-c", script],
10034*da0073e9SAndroid Build Coastguard Worker                stderr=subprocess.STDOUT,
10035*da0073e9SAndroid Build Coastguard Worker                # On Windows, opening the subprocess with the default CWD makes `import torch`
10036*da0073e9SAndroid Build Coastguard Worker                # fail, so just set CWD to this script's directory
10037*da0073e9SAndroid Build Coastguard Worker                cwd=os.path.dirname(os.path.realpath(__file__)),
10038*da0073e9SAndroid Build Coastguard Worker                # It is ok to have an extra long timeout here as a timeout means the test failed
10039*da0073e9SAndroid Build Coastguard Worker                timeout=20,
10040*da0073e9SAndroid Build Coastguard Worker            )
10041*da0073e9SAndroid Build Coastguard Worker        except subprocess.TimeoutExpired as e:
10042*da0073e9SAndroid Build Coastguard Worker            self.fail(
10043*da0073e9SAndroid Build Coastguard Worker                msg="Example code timed out! See the code sample in the test for details."
10044*da0073e9SAndroid Build Coastguard Worker            )
10045*da0073e9SAndroid Build Coastguard Worker        except subprocess.CalledProcessError as e:
10046*da0073e9SAndroid Build Coastguard Worker            if e.returncode < 0:
10047*da0073e9SAndroid Build Coastguard Worker                # Sometimes we segfault instead of deadlocking
10048*da0073e9SAndroid Build Coastguard Worker                self.fail("Subprocess exited with a fatal signal")
10049*da0073e9SAndroid Build Coastguard Worker            else:
10050*da0073e9SAndroid Build Coastguard Worker                err_msg = (
10051*da0073e9SAndroid Build Coastguard Worker                    "RuntimeError: one of the variables needed for gradient computation"
10052*da0073e9SAndroid Build Coastguard Worker                )
10053*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(err_msg in e.output.decode("utf-8"))
10054*da0073e9SAndroid Build Coastguard Worker
10055*da0073e9SAndroid Build Coastguard Worker    def test_view_func_replay(self):
10056*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._force_original_view_tracking(True):
10057*da0073e9SAndroid Build Coastguard Worker
10058*da0073e9SAndroid Build Coastguard Worker            def _assert_match_metadata(a, b):
10059*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.size(), b.size())
10060*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.stride(), b.stride())
10061*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.storage_offset(), b.storage_offset())
10062*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.device, b.device)
10063*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.dtype, b.dtype)
10064*da0073e9SAndroid Build Coastguard Worker
10065*da0073e9SAndroid Build Coastguard Worker            def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
10066*da0073e9SAndroid Build Coastguard Worker                outs = fn(inp, *args)
10067*da0073e9SAndroid Build Coastguard Worker                # handle functions that return multiple views (e.g. split)
10068*da0073e9SAndroid Build Coastguard Worker                if isinstance(outs, torch.Tensor):
10069*da0073e9SAndroid Build Coastguard Worker                    outs = [outs]
10070*da0073e9SAndroid Build Coastguard Worker
10071*da0073e9SAndroid Build Coastguard Worker                for out in outs:
10072*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(out._is_view())
10073*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(out._base is inp)
10074*da0073e9SAndroid Build Coastguard Worker
10075*da0073e9SAndroid Build Coastguard Worker                    # forward view_func
10076*da0073e9SAndroid Build Coastguard Worker                    new_inp = inp.clone()
10077*da0073e9SAndroid Build Coastguard Worker                    _assert_match_metadata(new_inp, inp)
10078*da0073e9SAndroid Build Coastguard Worker                    if use_unsafe_view_func:
10079*da0073e9SAndroid Build Coastguard Worker                        new_out = out._view_func_unsafe(new_inp)
10080*da0073e9SAndroid Build Coastguard Worker                    else:
10081*da0073e9SAndroid Build Coastguard Worker                        new_out = out._view_func(new_inp)
10082*da0073e9SAndroid Build Coastguard Worker                    _assert_match_metadata(new_out, out)
10083*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(new_out, out)
10084*da0073e9SAndroid Build Coastguard Worker
10085*da0073e9SAndroid Build Coastguard Worker                    # reverse view_func
10086*da0073e9SAndroid Build Coastguard Worker                    new_out = out.detach()
10087*da0073e9SAndroid Build Coastguard Worker                    new_inp = out._rev_view_func_unsafe(new_out)
10088*da0073e9SAndroid Build Coastguard Worker                    _assert_match_metadata(new_inp, inp)
10089*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(new_inp._is_view())
10090*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(new_inp._base is new_out)
10091*da0073e9SAndroid Build Coastguard Worker
10092*da0073e9SAndroid Build Coastguard Worker            # test individual view ops
10093*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.ops.aten.alias.default, torch.rand(2, 2))
10094*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.as_strided, torch.rand(2, 2), (4,), (1,))
10095*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.chunk, torch.rand(2, 4), 2, -1)
10096*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.diagonal, torch.rand(4, 4))
10097*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.ops.aten.expand.default, torch.rand(4, 1), (-1, 3))
10098*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.narrow, torch.rand(2, 2), 0, 1, 1)
10099*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.permute, torch.rand(2, 3, 4), (1, 0, 2))
10100*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.select, torch.rand(2, 2), 0, 0)
10101*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.ops.aten.slice.Tensor, torch.rand(2, 2), 1, 1, 2)
10102*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.split, torch.rand(2, 2), 1)
10103*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.split_with_sizes, torch.rand(2, 4), [1, 3], -1)
10104*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.squeeze, torch.rand(2, 1, 4))
10105*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.squeeze, torch.rand(2, 1, 4), 1)
10106*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.squeeze, torch.rand(2, 1, 1, 4), [1, 2])
10107*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.t, torch.rand(2, 4))
10108*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.transpose, torch.rand(2, 4), 0, 1)
10109*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.unbind, torch.rand(1, 5))
10110*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.ops.aten.unfold.default, torch.rand(1, 5), 1, 3, 2)
10111*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.unsqueeze, torch.rand(2, 4), -2)
10112*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.ops.aten.view.default, torch.rand(2, 10), (-1, 5, 2))
10113*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.view_as_complex, torch.rand(2, 2))
10114*da0073e9SAndroid Build Coastguard Worker            _test_fn(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat))
10115*da0073e9SAndroid Build Coastguard Worker
10116*da0073e9SAndroid Build Coastguard Worker            # test view chains
10117*da0073e9SAndroid Build Coastguard Worker            _test_fn(
10118*da0073e9SAndroid Build Coastguard Worker                lambda x: x.unsqueeze(-1).transpose(-1, -2).squeeze(1),
10119*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 4),
10120*da0073e9SAndroid Build Coastguard Worker            )
10121*da0073e9SAndroid Build Coastguard Worker            _test_fn(
10122*da0073e9SAndroid Build Coastguard Worker                lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1),
10123*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 3, 4),
10124*da0073e9SAndroid Build Coastguard Worker            )
10125*da0073e9SAndroid Build Coastguard Worker            _test_fn(
10126*da0073e9SAndroid Build Coastguard Worker                lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0),
10127*da0073e9SAndroid Build Coastguard Worker                torch.randn(2, 3, 4),
10128*da0073e9SAndroid Build Coastguard Worker            )
10129*da0073e9SAndroid Build Coastguard Worker
10130*da0073e9SAndroid Build Coastguard Worker            # chains with missing view_func()s use as_strided() to cover the gaps
10131*da0073e9SAndroid Build Coastguard Worker            def chain_with_only_parent_view_func(x):
10132*da0073e9SAndroid Build Coastguard Worker                with torch.autograd._force_original_view_tracking(True):
10133*da0073e9SAndroid Build Coastguard Worker                    x = x.split_with_sizes([1, 3], -1)[0]
10134*da0073e9SAndroid Build Coastguard Worker
10135*da0073e9SAndroid Build Coastguard Worker                with torch.autograd._force_original_view_tracking(False):
10136*da0073e9SAndroid Build Coastguard Worker                    x = x.chunk(2, 0)
10137*da0073e9SAndroid Build Coastguard Worker
10138*da0073e9SAndroid Build Coastguard Worker                return x
10139*da0073e9SAndroid Build Coastguard Worker
10140*da0073e9SAndroid Build Coastguard Worker            _test_fn(chain_with_only_parent_view_func, torch.randn(2, 3, 4))
10141*da0073e9SAndroid Build Coastguard Worker
10142*da0073e9SAndroid Build Coastguard Worker            def chain_with_only_current_view_func(x):
10143*da0073e9SAndroid Build Coastguard Worker                with torch.autograd._force_original_view_tracking(False):
10144*da0073e9SAndroid Build Coastguard Worker                    x = x.split_with_sizes([1, 3], -1)[0]
10145*da0073e9SAndroid Build Coastguard Worker
10146*da0073e9SAndroid Build Coastguard Worker                with torch.autograd._force_original_view_tracking(True):
10147*da0073e9SAndroid Build Coastguard Worker                    x = x.chunk(2, 0)
10148*da0073e9SAndroid Build Coastguard Worker
10149*da0073e9SAndroid Build Coastguard Worker                return x
10150*da0073e9SAndroid Build Coastguard Worker
10151*da0073e9SAndroid Build Coastguard Worker            _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))
10152*da0073e9SAndroid Build Coastguard Worker
10153*da0073e9SAndroid Build Coastguard Worker            # TODO: Move this somewhere else
10154*da0073e9SAndroid Build Coastguard Worker            # test NT views
10155*da0073e9SAndroid Build Coastguard Worker            from torch.nested._internal.nested_tensor import (
10156*da0073e9SAndroid Build Coastguard Worker                nested_view_from_values_offsets,
10157*da0073e9SAndroid Build Coastguard Worker            )
10158*da0073e9SAndroid Build Coastguard Worker
10159*da0073e9SAndroid Build Coastguard Worker            values = torch.randn(10, 5)
10160*da0073e9SAndroid Build Coastguard Worker            offsets = torch.tensor([0, 3, 6, 10])
10161*da0073e9SAndroid Build Coastguard Worker            _test_fn(nested_view_from_values_offsets, values, offsets)
10162*da0073e9SAndroid Build Coastguard Worker
10163*da0073e9SAndroid Build Coastguard Worker            nt = nested_view_from_values_offsets(values, offsets).clone().detach()
10164*da0073e9SAndroid Build Coastguard Worker            _test_fn(
10165*da0073e9SAndroid Build Coastguard Worker                torch.ops.aten._nested_get_values.default, nt, use_unsafe_view_func=True
10166*da0073e9SAndroid Build Coastguard Worker            )
10167*da0073e9SAndroid Build Coastguard Worker
10168*da0073e9SAndroid Build Coastguard Worker            def chain_nt_to_dense_back_and_forth(nt):
10169*da0073e9SAndroid Build Coastguard Worker                # NJT1 -> dense -> NJT2 -> dense
10170*da0073e9SAndroid Build Coastguard Worker                offsets2 = nt.offsets().clone().detach()
10171*da0073e9SAndroid Build Coastguard Worker                return nested_view_from_values_offsets(nt.values(), offsets2).values()
10172*da0073e9SAndroid Build Coastguard Worker
10173*da0073e9SAndroid Build Coastguard Worker            _test_fn(chain_nt_to_dense_back_and_forth, nt, use_unsafe_view_func=True)
10174*da0073e9SAndroid Build Coastguard Worker
10175*da0073e9SAndroid Build Coastguard Worker            def chain_dense_to_nt_back_and_forth(values, offsets):
10176*da0073e9SAndroid Build Coastguard Worker                offsets2 = offsets.clone().detach()
10177*da0073e9SAndroid Build Coastguard Worker                # dense -> NJT1 -> dense -> NJT2
10178*da0073e9SAndroid Build Coastguard Worker                return nested_view_from_values_offsets(
10179*da0073e9SAndroid Build Coastguard Worker                    nested_view_from_values_offsets(values, offsets).values(), offsets2
10180*da0073e9SAndroid Build Coastguard Worker                )
10181*da0073e9SAndroid Build Coastguard Worker
10182*da0073e9SAndroid Build Coastguard Worker            _test_fn(
10183*da0073e9SAndroid Build Coastguard Worker                chain_dense_to_nt_back_and_forth,
10184*da0073e9SAndroid Build Coastguard Worker                values,
10185*da0073e9SAndroid Build Coastguard Worker                offsets,
10186*da0073e9SAndroid Build Coastguard Worker                use_unsafe_view_func=True,
10187*da0073e9SAndroid Build Coastguard Worker            )
10188*da0073e9SAndroid Build Coastguard Worker
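    # A minimal sketch (a non-test helper, not collected by the suite) of the
    # replay mechanism exercised above: with original view tracking forced on,
    # a view records a view_func() that can be replayed on a fresh base to
    # reproduce the equivalent view of that base.
    def _sketch_view_func_replay(self):
        with torch.autograd._force_original_view_tracking(True):
            base = torch.randn(2, 3)
            view = base.transpose(0, 1)
            new_base = torch.randn(2, 3)
            # Replaying the recorded view_func on new_base produces the
            # matching view of new_base
            replayed = view._view_func(new_base)
            assert torch.equal(replayed, new_base.transpose(0, 1))
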
10189*da0073e9SAndroid Build Coastguard Worker    def test_view_func_replay_with_modified_state(self):
10190*da0073e9SAndroid Build Coastguard Worker        with torch.autograd._force_original_view_tracking(True):
10191*da0073e9SAndroid Build Coastguard Worker            base = torch.randn(3, 4, 5)
10192*da0073e9SAndroid Build Coastguard Worker            view = base.select(1, 2)
10193*da0073e9SAndroid Build Coastguard Worker
10194*da0073e9SAndroid Build Coastguard Worker            def symint_visitor_fn(x):
10195*da0073e9SAndroid Build Coastguard Worker                # modify saved index
10196*da0073e9SAndroid Build Coastguard Worker                return x + 1
10197*da0073e9SAndroid Build Coastguard Worker
10198*da0073e9SAndroid Build Coastguard Worker            # ensure modifying state changes view replay
10199*da0073e9SAndroid Build Coastguard Worker            new_base = torch.randn_like(base)
10200*da0073e9SAndroid Build Coastguard Worker            new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn)
10201*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(new_view, new_base.select(1, 3))
10202*da0073e9SAndroid Build Coastguard Worker
10203*da0073e9SAndroid Build Coastguard Worker            # ensure saved state reverts back afterwards
10204*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(view._view_func(new_base), new_base.select(1, 2))
10205*da0073e9SAndroid Build Coastguard Worker
10206*da0073e9SAndroid Build Coastguard Worker            # check modifying tensor state. currently, slice_inverse() is the only
10207*da0073e9SAndroid Build Coastguard Worker            # view that saves a tensor
10208*da0073e9SAndroid Build Coastguard Worker            base = torch.randn(3, 4, 5)
10209*da0073e9SAndroid Build Coastguard Worker            sliced = base[:, 2:3, :].detach()
10210*da0073e9SAndroid Build Coastguard Worker            view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1)
10211*da0073e9SAndroid Build Coastguard Worker
10212*da0073e9SAndroid Build Coastguard Worker            replacement_shape = (1, 2, 3)
10213*da0073e9SAndroid Build Coastguard Worker
10214*da0073e9SAndroid Build Coastguard Worker            def tensor_visitor_fn(x):
10215*da0073e9SAndroid Build Coastguard Worker                # return tensor with a smaller shape than the saved one
10216*da0073e9SAndroid Build Coastguard Worker                return torch.randn(*replacement_shape)
10217*da0073e9SAndroid Build Coastguard Worker
10218*da0073e9SAndroid Build Coastguard Worker            # ensure modifying state changes view replay
10219*da0073e9SAndroid Build Coastguard Worker            new_sliced = torch.ones_like(base)[:, 2:3, :].detach()
10220*da0073e9SAndroid Build Coastguard Worker            new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn)
10221*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(new_view.shape, replacement_shape)
10222*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
10223*da0073e9SAndroid Build Coastguard Worker                new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1))
10224*da0073e9SAndroid Build Coastguard Worker            )
10225*da0073e9SAndroid Build Coastguard Worker
10226*da0073e9SAndroid Build Coastguard Worker            # ensure saved state reverts back afterwards
10227*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(view._view_func(sliced), base)
10228*da0073e9SAndroid Build Coastguard Worker
10229*da0073e9SAndroid Build Coastguard Worker    def test_setup_context_when_forward_has_default_args(self):
10230*da0073e9SAndroid Build Coastguard Worker        class PowFunction(Function):
10231*da0073e9SAndroid Build Coastguard Worker            @staticmethod
10232*da0073e9SAndroid Build Coastguard Worker            def forward(x, y=3):
10233*da0073e9SAndroid Build Coastguard Worker                return torch.pow(x, y)
10234*da0073e9SAndroid Build Coastguard Worker
10235*da0073e9SAndroid Build Coastguard Worker            @staticmethod
10236*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output):
10237*da0073e9SAndroid Build Coastguard Worker                x, y = inputs
10238*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
10239*da0073e9SAndroid Build Coastguard Worker                ctx.y = y
10240*da0073e9SAndroid Build Coastguard Worker
10241*da0073e9SAndroid Build Coastguard Worker            @staticmethod
10242*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
10243*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
10244*da0073e9SAndroid Build Coastguard Worker                y = ctx.y
10245*da0073e9SAndroid Build Coastguard Worker                return gO * y * torch.pow(x, y - 1), None
10246*da0073e9SAndroid Build Coastguard Worker
10247*da0073e9SAndroid Build Coastguard Worker        class PowFunctionWithClassmethod(Function):
10248*da0073e9SAndroid Build Coastguard Worker            @classmethod
10249*da0073e9SAndroid Build Coastguard Worker            def forward(cls, x, y=3):
10250*da0073e9SAndroid Build Coastguard Worker                return torch.pow(x, y)
10251*da0073e9SAndroid Build Coastguard Worker
10252*da0073e9SAndroid Build Coastguard Worker            @classmethod
10253*da0073e9SAndroid Build Coastguard Worker            def setup_context(cls, ctx, inputs, output):
10254*da0073e9SAndroid Build Coastguard Worker                x, y = inputs
10255*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
10256*da0073e9SAndroid Build Coastguard Worker                ctx.y = y
10257*da0073e9SAndroid Build Coastguard Worker
10258*da0073e9SAndroid Build Coastguard Worker            @classmethod
10259*da0073e9SAndroid Build Coastguard Worker            def backward(cls, ctx, gO):
10260*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
10261*da0073e9SAndroid Build Coastguard Worker                y = ctx.y
10262*da0073e9SAndroid Build Coastguard Worker                return gO * y * torch.pow(x, y - 1), None
10263*da0073e9SAndroid Build Coastguard Worker
10264*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor(2.0, requires_grad=True)
10265*da0073e9SAndroid Build Coastguard Worker
10266*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor(8.0)
10267*da0073e9SAndroid Build Coastguard Worker        y_expected = torch.tensor(12.0)
10268*da0073e9SAndroid Build Coastguard Worker
10269*da0073e9SAndroid Build Coastguard Worker        y1 = PowFunction.apply(x)
10270*da0073e9SAndroid Build Coastguard Worker        (y1_expected,) = torch.autograd.grad(y1, x)
10271*da0073e9SAndroid Build Coastguard Worker
10272*da0073e9SAndroid Build Coastguard Worker        y2 = PowFunctionWithClassmethod.apply(x)
10273*da0073e9SAndroid Build Coastguard Worker        (y2_expected,) = torch.autograd.grad(y2, x)
10274*da0073e9SAndroid Build Coastguard Worker
10275*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y1)
10276*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_expected, y1_expected)
10277*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y, y2)
10278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y_expected, y2_expected)
10279*da0073e9SAndroid Build Coastguard Worker
10280*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
10281*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_default_device_placement_context(self):
10282*da0073e9SAndroid Build Coastguard Worker        # During gradcheck with fast_mode=True, we create a random vector on the CPU device using a CPU generator.
10283*da0073e9SAndroid Build Coastguard Worker        # This test ensures that this still works when the default device is set to something else by the user.
10284*da0073e9SAndroid Build Coastguard Worker        with torch.device("cuda"):
10285*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, dtype=torch.double, requires_grad=True)
10286*da0073e9SAndroid Build Coastguard Worker
10287*da0073e9SAndroid Build Coastguard Worker            def func(inp):
10288*da0073e9SAndroid Build Coastguard Worker                return inp**2.0
10289*da0073e9SAndroid Build Coastguard Worker
10290*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(gradcheck(func, x, fast_mode=True))
10291*da0073e9SAndroid Build Coastguard Worker
10292*da0073e9SAndroid Build Coastguard Worker
10293*da0073e9SAndroid Build Coastguard Workerdef index_perm_variable(shape, max_indices):
10294*da0073e9SAndroid Build Coastguard Worker    if not isinstance(shape, tuple):
10295*da0073e9SAndroid Build Coastguard Worker        shape = (shape,)
10296*da0073e9SAndroid Build Coastguard Worker
10297*da0073e9SAndroid Build Coastguard Worker    index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape)
10298*da0073e9SAndroid Build Coastguard Worker    return index
10299*da0073e9SAndroid Build Coastguard Worker
10300*da0073e9SAndroid Build Coastguard Worker
10301*da0073e9SAndroid Build Coastguard Workerdef bernoulli_scalar():
10302*da0073e9SAndroid Build Coastguard Worker    return torch.tensor(0, dtype=torch.uint8).bernoulli_()
10303*da0073e9SAndroid Build Coastguard Worker
10304*da0073e9SAndroid Build Coastguard Worker
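# These tests vmap a small jvp() helper over a batch of tangents to check
# batched forward-mode gradients: how tangents propagate, alias, or get
# copied when packed into duals under vmap.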
10305*da0073e9SAndroid Build Coastguard Workerclass TestAutogradForwardModeBatchedGrad(TestCase):
10306*da0073e9SAndroid Build Coastguard Worker    def test_out_of_place_basic(self):
10307*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
10308*da0073e9SAndroid Build Coastguard Worker        b = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
10309*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
10310*da0073e9SAndroid Build Coastguard Worker            gradcheck(
10311*da0073e9SAndroid Build Coastguard Worker                torch.sin,
10312*da0073e9SAndroid Build Coastguard Worker                a,
10313*da0073e9SAndroid Build Coastguard Worker                check_forward_ad=True,
10314*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=True,
10315*da0073e9SAndroid Build Coastguard Worker                check_batched_forward_grad=True,
10316*da0073e9SAndroid Build Coastguard Worker            )
10317*da0073e9SAndroid Build Coastguard Worker        )
10318*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(
10319*da0073e9SAndroid Build Coastguard Worker            gradcheck(
10320*da0073e9SAndroid Build Coastguard Worker                torch.add,
10321*da0073e9SAndroid Build Coastguard Worker                (a, b),
10322*da0073e9SAndroid Build Coastguard Worker                check_forward_ad=True,
10323*da0073e9SAndroid Build Coastguard Worker                check_batched_grad=True,
10324*da0073e9SAndroid Build Coastguard Worker                check_batched_forward_grad=True,
10325*da0073e9SAndroid Build Coastguard Worker            )
10326*da0073e9SAndroid Build Coastguard Worker        )
10327*da0073e9SAndroid Build Coastguard Worker
10328*da0073e9SAndroid Build Coastguard Worker    def test_out_of_place_not_same_layout(self):
10329*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros([2, 2]).transpose(0, 1)
10330*da0073e9SAndroid Build Coastguard Worker        tangent = torch.zeros([2, 2, 2])
10331*da0073e9SAndroid Build Coastguard Worker
10332*da0073e9SAndroid Build Coastguard Worker        def jvp(tangent):
10333*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
10334*da0073e9SAndroid Build Coastguard Worker                x = fwAD.make_dual(input, tangent)
10335*da0073e9SAndroid Build Coastguard Worker                return fwAD.unpack_dual(x)[1]
10336*da0073e9SAndroid Build Coastguard Worker
10337*da0073e9SAndroid Build Coastguard Worker        x_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
10338*da0073e9SAndroid Build Coastguard Worker
10339*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(x_tangent, tangent)
10340*da0073e9SAndroid Build Coastguard Worker
10341*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_same_layout(self):
10342*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros([2, 2])
10343*da0073e9SAndroid Build Coastguard Worker        tangent = torch.zeros([2, 2, 2])
10344*da0073e9SAndroid Build Coastguard Worker        base = torch.zeros([2, 2])
10345*da0073e9SAndroid Build Coastguard Worker        view = base.view_as(base)
10346*da0073e9SAndroid Build Coastguard Worker
10347*da0073e9SAndroid Build Coastguard Worker        def jvp(tangent):
10348*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
10349*da0073e9SAndroid Build Coastguard Worker                x = fwAD.make_dual(input, tangent)
10350*da0073e9SAndroid Build Coastguard Worker                view.copy_(x)
10351*da0073e9SAndroid Build Coastguard Worker                return (
10352*da0073e9SAndroid Build Coastguard Worker                    fwAD.unpack_dual(x)[1],
10353*da0073e9SAndroid Build Coastguard Worker                    fwAD.unpack_dual(view)[1],
10354*da0073e9SAndroid Build Coastguard Worker                    fwAD.unpack_dual(view._base)[1],
10355*da0073e9SAndroid Build Coastguard Worker                )
10356*da0073e9SAndroid Build Coastguard Worker
10357*da0073e9SAndroid Build Coastguard Worker        x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(
10358*da0073e9SAndroid Build Coastguard Worker            tangent
10359*da0073e9SAndroid Build Coastguard Worker        )
10360*da0073e9SAndroid Build Coastguard Worker
10361*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(
10362*da0073e9SAndroid Build Coastguard Worker            view_tangent._is_view()
10363*da0073e9SAndroid Build Coastguard Worker        )  # Optimization to share the same tensor!
10364*da0073e9SAndroid Build Coastguard Worker        self.assertIs(view_tangent, base_tangent)
10365*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x_tangent, tangent)
10366*da0073e9SAndroid Build Coastguard Worker
10367*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_not_same_layout(self):
10368*da0073e9SAndroid Build Coastguard Worker        input = torch.zeros([2, 2])
10369*da0073e9SAndroid Build Coastguard Worker        tangent = torch.zeros([2, 2, 2])
10370*da0073e9SAndroid Build Coastguard Worker        view = torch.zeros([2, 2]).transpose(0, 1)
10371*da0073e9SAndroid Build Coastguard Worker
10372*da0073e9SAndroid Build Coastguard Worker        def jvp(tangent):
10373*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
10374*da0073e9SAndroid Build Coastguard Worker                x = fwAD.make_dual(input, tangent)
10375*da0073e9SAndroid Build Coastguard Worker                view.copy_(x)
10376*da0073e9SAndroid Build Coastguard Worker                return (
10377*da0073e9SAndroid Build Coastguard Worker                    fwAD.unpack_dual(x)[1],
10378*da0073e9SAndroid Build Coastguard Worker                    fwAD.unpack_dual(view)[1],
10379*da0073e9SAndroid Build Coastguard Worker                    fwAD.unpack_dual(view._base)[1],
10380*da0073e9SAndroid Build Coastguard Worker                )
10381*da0073e9SAndroid Build Coastguard Worker
10382*da0073e9SAndroid Build Coastguard Worker        x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(
10383*da0073e9SAndroid Build Coastguard Worker            tangent
10384*da0073e9SAndroid Build Coastguard Worker        )
10385*da0073e9SAndroid Build Coastguard Worker
10386*da0073e9SAndroid Build Coastguard Worker        self.assertIs(view_tangent._base, base_tangent)
10387*da0073e9SAndroid Build Coastguard Worker        self.assertIs(x_tangent, tangent)
10388*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(view_tangent, tangent)
10389*da0073e9SAndroid Build Coastguard Worker
10390*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_for_storage_numel_skipped(self):
10391*da0073e9SAndroid Build Coastguard Worker        # See: test_metadata_check_checks_storage_numel for the reverse of this test
10392*da0073e9SAndroid Build Coastguard Worker        primal = torch.randn(5)[:4].detach()
10393*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(primal.storage()), 5)
10394*da0073e9SAndroid Build Coastguard Worker        tangent = torch.randn(10, 4)
10395*da0073e9SAndroid Build Coastguard Worker
10396*da0073e9SAndroid Build Coastguard Worker        def jvp(tangent):
10397*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
10398*da0073e9SAndroid Build Coastguard Worker                dual = fwAD.make_dual(primal, tangent)
10399*da0073e9SAndroid Build Coastguard Worker                _, unpacked_tangent = fwAD.unpack_dual(dual)
10400*da0073e9SAndroid Build Coastguard Worker
10401*da0073e9SAndroid Build Coastguard Worker                # No copy is made
10402*da0073e9SAndroid Build Coastguard Worker                self.assertIs(tangent, unpacked_tangent)
10403*da0073e9SAndroid Build Coastguard Worker
10404*da0073e9SAndroid Build Coastguard Worker                # as_strided raises
10405*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
10406*da0073e9SAndroid Build Coastguard Worker                    RuntimeError, "can access memory outside of `tensor`"
10407*da0073e9SAndroid Build Coastguard Worker                ):
10408*da0073e9SAndroid Build Coastguard Worker                    dual.as_strided((5,), (1,), 0)
10409*da0073e9SAndroid Build Coastguard Worker            return unpacked_tangent
10410*da0073e9SAndroid Build Coastguard Worker
10411*da0073e9SAndroid Build Coastguard Worker        torch._vmap_internals._vmap(jvp, 0, 0)(tangent)
10412*da0073e9SAndroid Build Coastguard Worker
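    # A minimal sketch (a non-test helper, not collected by the suite) of the
    # batched forward-grad pattern used in this class: vmap a jvp() helper over
    # a batch of tangents so a single primal is paired with every tangent in
    # the batch.
    def _sketch_batched_forward_grad(self):
        primal = torch.rand(3)
        tangents = torch.rand(5, 3)  # a batch of 5 tangents

        def jvp(tangent):
            with fwAD.dual_level():
                dual = fwAD.make_dual(primal, tangent)
                return fwAD.unpack_dual(dual.sin())[1]

        # Maps over dim 0 of `tangents`, returning one JVP per tangent
        jvps = torch._vmap_internals._vmap(jvp, 0, 0)(tangents)
        assert jvps.shape == (5, 3)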
10413*da0073e9SAndroid Build Coastguard Worker
10414*da0073e9SAndroid Build Coastguard Workerclass TestAutogradForwardMode(TestCase):
10415*da0073e9SAndroid Build Coastguard Worker    def tearDown(self):
10416*da0073e9SAndroid Build Coastguard Worker        # Ensure that a failing test won't make others fail
10417*da0073e9SAndroid Build Coastguard Worker        while fwAD._current_level >= 0:
10418*da0073e9SAndroid Build Coastguard Worker            fwAD.exit_dual_level()
10419*da0073e9SAndroid Build Coastguard Worker
10420*da0073e9SAndroid Build Coastguard Worker        super().tearDown()
10421*da0073e9SAndroid Build Coastguard Worker
10422*da0073e9SAndroid Build Coastguard Worker    def test_forward_level_cleanup(self):
10423*da0073e9SAndroid Build Coastguard Worker        def get_tensor_and_weak_ref():
10424*da0073e9SAndroid Build Coastguard Worker            # Create a new Tensor and weak reference
10425*da0073e9SAndroid Build Coastguard Worker            t = torch.rand(2, requires_grad=True)
10426*da0073e9SAndroid Build Coastguard Worker            return t, torch._C._WeakTensorRef(t)
10427*da0073e9SAndroid Build Coastguard Worker
10428*da0073e9SAndroid Build Coastguard Worker        # Sanity check that the helper function works as expected
10429*da0073e9SAndroid Build Coastguard Worker        t, t_ref = get_tensor_and_weak_ref()
10430*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(t_ref.expired())
10431*da0073e9SAndroid Build Coastguard Worker
10432*da0073e9SAndroid Build Coastguard Worker        del t
10433*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(t_ref.expired())
10434*da0073e9SAndroid Build Coastguard Worker
10435*da0073e9SAndroid Build Coastguard Worker        # Main test code
10436*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10437*da0073e9SAndroid Build Coastguard Worker
10438*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10439*da0073e9SAndroid Build Coastguard Worker            tangent, tangent_ref = get_tensor_and_weak_ref()
10440*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(tangent_ref.expired())
10441*da0073e9SAndroid Build Coastguard Worker
10442*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, tangent)
10443*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(tangent_ref.expired())
10444*da0073e9SAndroid Build Coastguard Worker
10445*da0073e9SAndroid Build Coastguard Worker            # Make sure that the tangent we provided has been re-used as is
10446*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)
10447*da0073e9SAndroid Build Coastguard Worker
10448*da0073e9SAndroid Build Coastguard Worker            # Make sure that dual is keeping the tangent alive
10449*da0073e9SAndroid Build Coastguard Worker            del tangent
10450*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(tangent_ref.expired())
10451*da0073e9SAndroid Build Coastguard Worker
10452*da0073e9SAndroid Build Coastguard Worker            # Make sure that the dual level does not keep the C++
10453*da0073e9SAndroid Build Coastguard Worker            # version of the tangent alive
10454*da0073e9SAndroid Build Coastguard Worker            del dual
10455*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(tangent_ref.expired())
10456*da0073e9SAndroid Build Coastguard Worker
10457*da0073e9SAndroid Build Coastguard Worker    def test_size_check(self):
10458*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10459*da0073e9SAndroid Build Coastguard Worker        tangent = torch.rand(3)
10460*da0073e9SAndroid Build Coastguard Worker
10461*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10462*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
10463*da0073e9SAndroid Build Coastguard Worker                RuntimeError,
10464*da0073e9SAndroid Build Coastguard Worker                "Trying to set a forward gradient that has a different size",
10465*da0073e9SAndroid Build Coastguard Worker            ):
10466*da0073e9SAndroid Build Coastguard Worker                dual = fwAD.make_dual(foo, tangent)
10467*da0073e9SAndroid Build Coastguard Worker
10468*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, tangent[1:])
10469*da0073e9SAndroid Build Coastguard Worker
10470*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_checks_storage_numel(self):
10471*da0073e9SAndroid Build Coastguard Worker        primal = torch.randn(5)[:4].detach()
10472*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(primal.storage()), 5)
10473*da0073e9SAndroid Build Coastguard Worker        tangent = torch.randn(4)
10474*da0073e9SAndroid Build Coastguard Worker
10475*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10476*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(primal, tangent)
10477*da0073e9SAndroid Build Coastguard Worker            _, unpacked_tangent = fwAD.unpack_dual(dual)
10478*da0073e9SAndroid Build Coastguard Worker
10479*da0073e9SAndroid Build Coastguard Worker            # Verify that mutating the unpacked tangent does not affect the original tangent
10480*da0073e9SAndroid Build Coastguard Worker            tangent_clone = tangent.clone()
10481*da0073e9SAndroid Build Coastguard Worker            unpacked_tangent *= 2
10482*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(tangent_clone, tangent))
10483*da0073e9SAndroid Build Coastguard Worker
10484*da0073e9SAndroid Build Coastguard Worker            # as_strided runs without error
10485*da0073e9SAndroid Build Coastguard Worker            dual.as_strided((5,), (1,), 0)
10486*da0073e9SAndroid Build Coastguard Worker
10487*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_checks_ignores_size_zero(self):
10488*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(0).as_strided((0, 1), (1, 1), 0)
10489*da0073e9SAndroid Build Coastguard Worker        b = torch.ones(0).as_strided((0, 1), (1, 0), 0)
10490*da0073e9SAndroid Build Coastguard Worker
10491*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10492*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(a, b)
10493*da0073e9SAndroid Build Coastguard Worker            torch.diagonal(dual, offset=0)
10494*da0073e9SAndroid Build Coastguard Worker
10495*da0073e9SAndroid Build Coastguard Worker        input = torch.rand([0, 1], dtype=torch.complex128, requires_grad=True)
10496*da0073e9SAndroid Build Coastguard Worker        func = partial(torch.diagonal, offset=0)
10497*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(func, (input,), check_forward_ad=True)
10498*da0073e9SAndroid Build Coastguard Worker
10499*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_when_primal_has_conj_bit(self):
10500*da0073e9SAndroid Build Coastguard Worker        # Make sure the _has_same_storage_numel check is a fallthrough, so that
10501*da0073e9SAndroid Build Coastguard Worker        # the conj bit does not materialize. If it materializes, it would
10502*da0073e9SAndroid Build Coastguard Worker        # cause the layout check to fail for views that do not index
10503*da0073e9SAndroid Build Coastguard Worker        # the entire storage.
10504*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, dtype=torch.cdouble).conj()
10505*da0073e9SAndroid Build Coastguard Worker        b = torch.rand_like(a)
10506*da0073e9SAndroid Build Coastguard Worker
10507*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.is_conj(a))
10508*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(a.storage()), len(b.storage()))
10509*da0073e9SAndroid Build Coastguard Worker
10510*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10511*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(a, b)
10512*da0073e9SAndroid Build Coastguard Worker            dual[1:]
10513*da0073e9SAndroid Build Coastguard Worker
10514*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_when_primal_has_neg_bit(self):
10515*da0073e9SAndroid Build Coastguard Worker        # Make sure the _has_same_storage_numel check is a fallthrough, so that
10516*da0073e9SAndroid Build Coastguard Worker        # the neg bit does not materialize. If it materializes, it would
10517*da0073e9SAndroid Build Coastguard Worker        # cause the layout check to fail for views that do not index
10518*da0073e9SAndroid Build Coastguard Worker        # the entire storage.
10519*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 2, dtype=torch.cdouble).conj().imag
10520*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, dtype=torch.cdouble).imag
10521*da0073e9SAndroid Build Coastguard Worker
10522*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.is_neg(a))
10523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(a.storage()), len(b.storage()))
10524*da0073e9SAndroid Build Coastguard Worker
10525*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10526*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(a, b)
10527*da0073e9SAndroid Build Coastguard Worker            dual[1:]
10528*da0073e9SAndroid Build Coastguard Worker
10529*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_check_conj(self):
10530*da0073e9SAndroid Build Coastguard Worker        keys = {
10531*da0073e9SAndroid Build Coastguard Worker            "NEITHER": lambda x: x,
10532*da0073e9SAndroid Build Coastguard Worker            "CONJ": lambda x: x.conj(),
10533*da0073e9SAndroid Build Coastguard Worker            "NEG": lambda x: x._neg_view(),
10534*da0073e9SAndroid Build Coastguard Worker        }
10535*da0073e9SAndroid Build Coastguard Worker
10536*da0073e9SAndroid Build Coastguard Worker        for primal_key, tangent_key in product(keys, keys):
10537*da0073e9SAndroid Build Coastguard Worker            x = keys[primal_key](torch.randn(2, 3, 4, dtype=torch.cdouble))
10538*da0073e9SAndroid Build Coastguard Worker            t = keys[tangent_key](torch.randn(2, 3, 4, dtype=torch.cdouble))
10539*da0073e9SAndroid Build Coastguard Worker
10540*da0073e9SAndroid Build Coastguard Worker            if primal_key == tangent_key:
10541*da0073e9SAndroid Build Coastguard Worker                with fwAD.dual_level():
10542*da0073e9SAndroid Build Coastguard Worker                    dual = fwAD.make_dual(x, t)
10543*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(fwAD.unpack_dual(dual).tangent is t)
10544*da0073e9SAndroid Build Coastguard Worker                    torch.real(dual)
10545*da0073e9SAndroid Build Coastguard Worker                    torch.imag(dual)
10546*da0073e9SAndroid Build Coastguard Worker            else:
10547*da0073e9SAndroid Build Coastguard Worker                with fwAD.dual_level():
10548*da0073e9SAndroid Build Coastguard Worker                    dual = fwAD.make_dual(x, t)
10549*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(fwAD.unpack_dual(dual).tangent is not t)
10550*da0073e9SAndroid Build Coastguard Worker                    torch.real(dual)
10551*da0073e9SAndroid Build Coastguard Worker                    torch.imag(dual)
10552*da0073e9SAndroid Build Coastguard Worker
10553*da0073e9SAndroid Build Coastguard Worker    def test_metadata_check_ignore_storage_offset_for_zero_numel_tensor(self):
10554*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/80507
10555*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0]).as_strided((0,), (1,), 1)
10556*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([1.0]).as_strided((0,), (1,), 2)
10557*da0073e9SAndroid Build Coastguard Worker
10558*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10559*da0073e9SAndroid Build Coastguard Worker            dual_input = fwAD.make_dual(a, b)
10560*da0073e9SAndroid Build Coastguard Worker            # Check that no copy is made
10561*da0073e9SAndroid Build Coastguard Worker            self.assertIs(fwAD.unpack_dual(dual_input).tangent, b)
10562*da0073e9SAndroid Build Coastguard Worker
10563*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0]).as_strided((1,), (2,), 0)
10564*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([1.0]).as_strided((1,), (1,), 0)
10565*da0073e9SAndroid Build Coastguard Worker
10566*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10567*da0073e9SAndroid Build Coastguard Worker            dual_input = fwAD.make_dual(a, b)
10568*da0073e9SAndroid Build Coastguard Worker            dual_input[1:]
10569*da0073e9SAndroid Build Coastguard Worker
10570*da0073e9SAndroid Build Coastguard Worker    # The following test functions ensure all of the following behaviors (a brief usage sketch follows this list):
10571*da0073e9SAndroid Build Coastguard Worker    #   - Ensure that the default level system in the Python bindings works
10572*da0073e9SAndroid Build Coastguard Worker    #   - Ensure that only level 0 exists and nesting is properly disabled
10573*da0073e9SAndroid Build Coastguard Worker    #   - Ensure that printing works fine
10574*da0073e9SAndroid Build Coastguard Worker    #   - Ensure that basic packing/unpacking works
10575*da0073e9SAndroid Build Coastguard Worker    #   - Ensure that advanced packing/unpacking works
10576*da0073e9SAndroid Build Coastguard Worker    #     - For memory / version counter sharing
10577*da0073e9SAndroid Build Coastguard Worker    #     - For backward AD (regular ops)
10578*da0073e9SAndroid Build Coastguard Worker    #   - Ensure that view + inplace for both modes work fine
10579*da0073e9SAndroid Build Coastguard Worker    #   - Ensure we do proper cleanup on exit of a level
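    #
    # A minimal usage sketch of the fwAD API exercised by these tests (for
    # orientation only; the precise behavior is asserted in the tests below):
    #
    #     with fwAD.dual_level():
    #         dual = fwAD.make_dual(primal, tangent)            # pack
    #         primal_out, tangent_out = fwAD.unpack_dual(dual)  # unpack
    #     # outside the level, unpack_dual(dual).tangent is None again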
10580*da0073e9SAndroid Build Coastguard Worker
10581*da0073e9SAndroid Build Coastguard Worker    def test_default_level(self):
10582*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10583*da0073e9SAndroid Build Coastguard Worker        bar = torch.rand(2)
10584*da0073e9SAndroid Build Coastguard Worker
10585*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10586*da0073e9SAndroid Build Coastguard Worker            baz = fwAD.make_dual(foo, bar)
10587*da0073e9SAndroid Build Coastguard Worker            baz_primal, baz_tangent = fwAD.unpack_dual(baz)
10588*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(baz_primal, foo)
10589*da0073e9SAndroid Build Coastguard Worker        # We don't actually need to enforce that these two are the exact same Python
10590*da0073e9SAndroid Build Coastguard Worker        # object; feel free to relax this in the future
10591*da0073e9SAndroid Build Coastguard Worker        self.assertIs(baz_tangent, bar)
10592*da0073e9SAndroid Build Coastguard Worker
10593*da0073e9SAndroid Build Coastguard Worker        baz_primal, baz_tangent = fwAD.unpack_dual(baz)
10594*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(baz_primal, foo)
10595*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(baz_tangent, None)
10596*da0073e9SAndroid Build Coastguard Worker
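    # The next two tests poke at private helpers. Roughly, based on how they are
    # used below: torch._C._set_fwd_grad_enabled(bool) flips a global flag that
    # fwAD._is_fwd_grad_enabled() reads, and fwAD._set_fwd_grad_enabled(...) is
    # used as a context manager that restores the previous value on exit.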
10597*da0073e9SAndroid Build Coastguard Worker    def test_fwd_grad_enabled(self):
10598*da0073e9SAndroid Build Coastguard Worker        # Tests some private helper functions to enable/disable fwd grad mode
10599*da0073e9SAndroid Build Coastguard Worker        enabled = fwAD._is_fwd_grad_enabled()
10600*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(enabled)
10601*da0073e9SAndroid Build Coastguard Worker
10602*da0073e9SAndroid Build Coastguard Worker        try:
10603*da0073e9SAndroid Build Coastguard Worker            torch._C._set_fwd_grad_enabled(False)
10604*da0073e9SAndroid Build Coastguard Worker            enabled = fwAD._is_fwd_grad_enabled()
10605*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(enabled)
10606*da0073e9SAndroid Build Coastguard Worker        finally:
10607*da0073e9SAndroid Build Coastguard Worker            torch._C._set_fwd_grad_enabled(True)
10608*da0073e9SAndroid Build Coastguard Worker
10609*da0073e9SAndroid Build Coastguard Worker        enabled = fwAD._is_fwd_grad_enabled()
10610*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(enabled)
10611*da0073e9SAndroid Build Coastguard Worker
10612*da0073e9SAndroid Build Coastguard Worker    def test_set_fwd_grad_enabled(self):
10613*da0073e9SAndroid Build Coastguard Worker        # Tests a private helper function
10614*da0073e9SAndroid Build Coastguard Worker        try:
10615*da0073e9SAndroid Build Coastguard Worker            torch._C._set_fwd_grad_enabled(False)
10616*da0073e9SAndroid Build Coastguard Worker            enabled = fwAD._is_fwd_grad_enabled()
10617*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(enabled)
10618*da0073e9SAndroid Build Coastguard Worker
10619*da0073e9SAndroid Build Coastguard Worker            with fwAD._set_fwd_grad_enabled(True):
10620*da0073e9SAndroid Build Coastguard Worker                enabled = fwAD._is_fwd_grad_enabled()
10621*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(enabled)
10622*da0073e9SAndroid Build Coastguard Worker
10623*da0073e9SAndroid Build Coastguard Worker            enabled = fwAD._is_fwd_grad_enabled()
10624*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(enabled)
10625*da0073e9SAndroid Build Coastguard Worker        finally:
10626*da0073e9SAndroid Build Coastguard Worker            torch._C._set_fwd_grad_enabled(True)
10627*da0073e9SAndroid Build Coastguard Worker
10628*da0073e9SAndroid Build Coastguard Worker    def test_nested_level(self):
10629*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level() as level:
10630*da0073e9SAndroid Build Coastguard Worker            # For now only level 0 exists
10631*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(level, 0)
10632*da0073e9SAndroid Build Coastguard Worker
10633*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10634*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
10635*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Nested forward mode AD is not supported at the moment"
10636*da0073e9SAndroid Build Coastguard Worker            ):
10637*da0073e9SAndroid Build Coastguard Worker                nest_level = fwAD.enter_dual_level()
10638*da0073e9SAndroid Build Coastguard Worker
10639*da0073e9SAndroid Build Coastguard Worker    def test_set_fw_grad_having_own_fw_grad_at_same_level(self):
10640*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10641*da0073e9SAndroid Build Coastguard Worker        bar = torch.rand(2)
10642*da0073e9SAndroid Build Coastguard Worker        baz = torch.rand(2)
10643*da0073e9SAndroid Build Coastguard Worker
10644*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10645*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, bar)
10646*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
10647*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "has a forward gradient at the same level"
10648*da0073e9SAndroid Build Coastguard Worker            ):
10649*da0073e9SAndroid Build Coastguard Worker                fwAD.make_dual(baz, dual)
10650*da0073e9SAndroid Build Coastguard Worker
10651*da0073e9SAndroid Build Coastguard Worker    def test_codegen_ignores_undefined_outputs(self):
10652*da0073e9SAndroid Build Coastguard Worker        # This test checks that codegen silently ignores undefined outputs
10653*da0073e9SAndroid Build Coastguard Worker        # Below, grad_input is specified as False in the output mask, so
10654*da0073e9SAndroid Build Coastguard Worker        # convolution backward will return an undefined tensor in that position.
10655*da0073e9SAndroid Build Coastguard Worker        # Note that for this test to work we need to make sure that either grad_output
10656*da0073e9SAndroid Build Coastguard Worker        # or weight is a dual tensor, so that grad_input requires a forward grad
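        # For reference (based on the call below): the final (False, True, False)
        # argument is the output mask over (grad_input, grad_weight, grad_bias),
        # so only grad_weight is requested and the first returned value comes
        # back undefined (None on the Python side).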
10657*da0073e9SAndroid Build Coastguard Worker        weight = torch.randn(6, 1, 30, 30)
10658*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand((1, 1, 32, 32))
10659*da0073e9SAndroid Build Coastguard Worker        out = torch.nn.functional.conv2d(inp, weight)
10660*da0073e9SAndroid Build Coastguard Worker        grad_out = torch.ones_like(out)
10661*da0073e9SAndroid Build Coastguard Worker
10662*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10663*da0073e9SAndroid Build Coastguard Worker            dual_weight = fwAD.make_dual(weight, torch.ones_like(weight))
10664*da0073e9SAndroid Build Coastguard Worker            grad_input, _, _ = torch.ops.aten.convolution_backward(
10665*da0073e9SAndroid Build Coastguard Worker                grad_out,
10666*da0073e9SAndroid Build Coastguard Worker                inp,
10667*da0073e9SAndroid Build Coastguard Worker                dual_weight,
10668*da0073e9SAndroid Build Coastguard Worker                (0,),
10669*da0073e9SAndroid Build Coastguard Worker                (1, 1),
10670*da0073e9SAndroid Build Coastguard Worker                (0, 0),
10671*da0073e9SAndroid Build Coastguard Worker                (1, 1),
10672*da0073e9SAndroid Build Coastguard Worker                False,
10673*da0073e9SAndroid Build Coastguard Worker                (0, 0),
10674*da0073e9SAndroid Build Coastguard Worker                1,
10675*da0073e9SAndroid Build Coastguard Worker                (False, True, False),
10676*da0073e9SAndroid Build Coastguard Worker            )
10677*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(grad_input)
10678*da0073e9SAndroid Build Coastguard Worker
10679*da0073e9SAndroid Build Coastguard Worker    def test_make_dual_inference_tensor_in_inference_mode(self):
10680*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
10681*da0073e9SAndroid Build Coastguard Worker            foo = torch.rand(2)
10682*da0073e9SAndroid Build Coastguard Worker            bar = torch.rand(2)
10683*da0073e9SAndroid Build Coastguard Worker            foo_copy = foo.clone()
10684*da0073e9SAndroid Build Coastguard Worker
10685*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
10686*da0073e9SAndroid Build Coastguard Worker                dual = fwAD.make_dual(foo, bar)
10687*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(dual._is_view())
10688*da0073e9SAndroid Build Coastguard Worker
10689*da0073e9SAndroid Build Coastguard Worker                dual += 1
10690*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.allclose(foo, foo_copy))
10691*da0073e9SAndroid Build Coastguard Worker
10692*da0073e9SAndroid Build Coastguard Worker    def test_make_dual_torch_dispatch(self):
10693*da0073e9SAndroid Build Coastguard Worker        counter = [0]
10694*da0073e9SAndroid Build Coastguard Worker
10695*da0073e9SAndroid Build Coastguard Worker        class MySubclass(torch.Tensor):
10696*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, data=None):
10697*da0073e9SAndroid Build Coastguard Worker                return torch.Tensor._make_subclass(cls, data)
10698*da0073e9SAndroid Build Coastguard Worker
10699*da0073e9SAndroid Build Coastguard Worker            @classmethod
10700*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
10701*da0073e9SAndroid Build Coastguard Worker                if func.overloadpacket == torch.ops.aten.alias:
10702*da0073e9SAndroid Build Coastguard Worker                    counter[0] += 1
10703*da0073e9SAndroid Build Coastguard Worker
10704*da0073e9SAndroid Build Coastguard Worker                    # Make sure we can re-enable autograd here
10705*da0073e9SAndroid Build Coastguard Worker                    with torch.overrides.enable_reentrant_dispatch():
10706*da0073e9SAndroid Build Coastguard Worker                        foo = torch.rand(1, requires_grad=True)
10707*da0073e9SAndroid Build Coastguard Worker                        self.assertIsNotNone(foo.exp().grad_fn)
10708*da0073e9SAndroid Build Coastguard Worker
10709*da0073e9SAndroid Build Coastguard Worker                with no_dispatch():
10710*da0073e9SAndroid Build Coastguard Worker                    return func(*args, **kwargs)
10711*da0073e9SAndroid Build Coastguard Worker
10712*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0)
10713*da0073e9SAndroid Build Coastguard Worker        s = MySubclass(a)
10714*da0073e9SAndroid Build Coastguard Worker
10715*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10716*da0073e9SAndroid Build Coastguard Worker            # Only the primal has "alias" called on it
10717*da0073e9SAndroid Build Coastguard Worker            fwAD.make_dual(s, torch.rand_like(s))
10718*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counter[0], 1)
10719*da0073e9SAndroid Build Coastguard Worker            fwAD.make_dual(torch.rand_like(s), s)
10720*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counter[0], 1)
10721*da0073e9SAndroid Build Coastguard Worker
10722*da0073e9SAndroid Build Coastguard Worker    def test_make_dual_forbid_integral_dtype(self):
10723*da0073e9SAndroid Build Coastguard Worker        primal_f = torch.ones(2, 2, dtype=torch.float)
10724*da0073e9SAndroid Build Coastguard Worker        primal_l = torch.ones(2, 2, dtype=torch.long)
10725*da0073e9SAndroid Build Coastguard Worker
10726*da0073e9SAndroid Build Coastguard Worker        tangent_f = torch.ones(2, 2, dtype=torch.float)
10727*da0073e9SAndroid Build Coastguard Worker        tangent_l = torch.ones(2, 2, dtype=torch.long)
10728*da0073e9SAndroid Build Coastguard Worker
10729*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10730*da0073e9SAndroid Build Coastguard Worker            # Float Primal and Long Tangent
10731*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
10732*da0073e9SAndroid Build Coastguard Worker                ValueError, "Expected tangent to be floating point or complex"
10733*da0073e9SAndroid Build Coastguard Worker            ):
10734*da0073e9SAndroid Build Coastguard Worker                fwAD.make_dual(primal_f, tangent_l)
10735*da0073e9SAndroid Build Coastguard Worker
10736*da0073e9SAndroid Build Coastguard Worker            # Long Primal and Long Tangent
10737*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
10738*da0073e9SAndroid Build Coastguard Worker                ValueError, "Expected primal to be floating point or complex"
10739*da0073e9SAndroid Build Coastguard Worker            ):
10740*da0073e9SAndroid Build Coastguard Worker                fwAD.make_dual(primal_l, tangent_l)
10741*da0073e9SAndroid Build Coastguard Worker
10742*da0073e9SAndroid Build Coastguard Worker            # Long Primal and Float Tangent
10743*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
10744*da0073e9SAndroid Build Coastguard Worker                ValueError, "Expected primal to be floating point or complex"
10745*da0073e9SAndroid Build Coastguard Worker            ):
10746*da0073e9SAndroid Build Coastguard Worker                fwAD.make_dual(primal_l, tangent_f)
10747*da0073e9SAndroid Build Coastguard Worker
10748*da0073e9SAndroid Build Coastguard Worker    def test_print(self):
10749*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level() as level:
10750*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(3)
10751*da0073e9SAndroid Build Coastguard Worker            self.assertFalse("tangent=" in str(a))
10752*da0073e9SAndroid Build Coastguard Worker
10753*da0073e9SAndroid Build Coastguard Worker            b = fwAD.make_dual(a, torch.rand(3))
10754*da0073e9SAndroid Build Coastguard Worker            self.assertFalse("tangent=" in str(a))
10755*da0073e9SAndroid Build Coastguard Worker            self.assertTrue("tangent=" in str(b))
10756*da0073e9SAndroid Build Coastguard Worker
10757*da0073e9SAndroid Build Coastguard Worker            b_primal, b_tangent = fwAD.unpack_dual(b)
10758*da0073e9SAndroid Build Coastguard Worker            self.assertFalse("tangent=" in str(b_primal))
10759*da0073e9SAndroid Build Coastguard Worker            self.assertFalse("tangent=" in str(b_tangent))
10760*da0073e9SAndroid Build Coastguard Worker
10761*da0073e9SAndroid Build Coastguard Worker    def test_basic_packing_unpacking(self):
10762*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10763*da0073e9SAndroid Build Coastguard Worker        bar = torch.rand(2)
10764*da0073e9SAndroid Build Coastguard Worker
10765*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10766*da0073e9SAndroid Build Coastguard Worker            baz = fwAD.make_dual(foo, bar)
10767*da0073e9SAndroid Build Coastguard Worker            baz_primal, baz_tangent = fwAD.unpack_dual(baz)
10768*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(baz_primal, foo)
10769*da0073e9SAndroid Build Coastguard Worker            self.assertIs(baz_tangent, bar)
10770*da0073e9SAndroid Build Coastguard Worker
10771*da0073e9SAndroid Build Coastguard Worker            # Check unpacked dual is returned as a named tuple
10772*da0073e9SAndroid Build Coastguard Worker            # NB: Every invocation of unpack_dual returns a new tensor view
10773*da0073e9SAndroid Build Coastguard Worker            self.assertIsNot(baz_primal, fwAD.unpack_dual(baz).primal)
10774*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(baz_primal, fwAD.unpack_dual(baz).primal)
10775*da0073e9SAndroid Build Coastguard Worker            self.assertIs(baz_tangent, fwAD.unpack_dual(baz).tangent)
10776*da0073e9SAndroid Build Coastguard Worker
10777*da0073e9SAndroid Build Coastguard Worker            # Check that packing/unpacking did not change the input
10778*da0073e9SAndroid Build Coastguard Worker            foo_primal, foo_tangent = fwAD.unpack_dual(foo)
10779*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo_primal, foo)
10780*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(foo_tangent)
10781*da0073e9SAndroid Build Coastguard Worker
10782*da0073e9SAndroid Build Coastguard Worker    def test_advanced_packing_unpacking(self):
10783*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10784*da0073e9SAndroid Build Coastguard Worker        bar = torch.ones(2)
10785*da0073e9SAndroid Build Coastguard Worker
10786*da0073e9SAndroid Build Coastguard Worker        # Memory and version counter check
10787*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10788*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, bar)
10789*da0073e9SAndroid Build Coastguard Worker
10790*da0073e9SAndroid Build Coastguard Worker            # Ensure that they are sharing memory and version counter
10791*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dual.storage().data_ptr(), foo.storage().data_ptr())
10792*da0073e9SAndroid Build Coastguard Worker
10793*da0073e9SAndroid Build Coastguard Worker            # Ensure we properly share the version counter
10794*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo._version, dual._version)
10795*da0073e9SAndroid Build Coastguard Worker            foo.add_(1)
10796*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo._version, dual._version)
10797*da0073e9SAndroid Build Coastguard Worker
10798*da0073e9SAndroid Build Coastguard Worker            # Unpacking should also only create aliases
10799*da0073e9SAndroid Build Coastguard Worker            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
10800*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dual_primal.storage().data_ptr(), foo.storage().data_ptr())
10801*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(
10802*da0073e9SAndroid Build Coastguard Worker                dual_tangent.storage().data_ptr(), bar.storage().data_ptr()
10803*da0073e9SAndroid Build Coastguard Worker            )
10804*da0073e9SAndroid Build Coastguard Worker            # And the tangent is actually re-used as-is so it is still the same Tensor
10805*da0073e9SAndroid Build Coastguard Worker            self.assertIs(dual_tangent, bar)
10806*da0073e9SAndroid Build Coastguard Worker
10807*da0073e9SAndroid Build Coastguard Worker            # Ensure we properly share the version counter
10808*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo._version, dual_primal._version)
10809*da0073e9SAndroid Build Coastguard Worker            foo.add_(1)
10810*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo._version, dual_primal._version)
10811*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bar._version, dual_tangent._version)
10812*da0073e9SAndroid Build Coastguard Worker            bar.add_(1)
10813*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bar._version, dual_tangent._version)
10814*da0073e9SAndroid Build Coastguard Worker
10815*da0073e9SAndroid Build Coastguard Worker        # backward mode check
10816*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10817*da0073e9SAndroid Build Coastguard Worker            foo.requires_grad_()
10818*da0073e9SAndroid Build Coastguard Worker            bar.requires_grad_()
10819*da0073e9SAndroid Build Coastguard Worker
10820*da0073e9SAndroid Build Coastguard Worker            # Check that backward gradients properly propagate through packing/unpacking
10821*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, bar)
10822*da0073e9SAndroid Build Coastguard Worker            p, t = fwAD.unpack_dual(dual)
10823*da0073e9SAndroid Build Coastguard Worker
10824*da0073e9SAndroid Build Coastguard Worker            gfoo, gbar = torch.autograd.grad(
10825*da0073e9SAndroid Build Coastguard Worker                p.sum(), (foo, bar), retain_graph=True, allow_unused=True
10826*da0073e9SAndroid Build Coastguard Worker            )
10827*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gfoo, torch.ones_like(foo))
10828*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(gbar)
10829*da0073e9SAndroid Build Coastguard Worker
10830*da0073e9SAndroid Build Coastguard Worker            gfoo, gbar = torch.autograd.grad(
10831*da0073e9SAndroid Build Coastguard Worker                t.sum(), (foo, bar), retain_graph=True, allow_unused=True
10832*da0073e9SAndroid Build Coastguard Worker            )
10833*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(gfoo)
10834*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(gbar, torch.ones_like(bar))
10835*da0073e9SAndroid Build Coastguard Worker
10836*da0073e9SAndroid Build Coastguard Worker            # Check that forward gradients are impacted by detach()
10837*da0073e9SAndroid Build Coastguard Worker            detached_dual = dual.detach()
10838*da0073e9SAndroid Build Coastguard Worker            out = detached_dual * 2
10839*da0073e9SAndroid Build Coastguard Worker            p, t = fwAD.unpack_dual(out)
10840*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.requires_grad)
10841*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p, foo * 2)
10842*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(t)
10843*da0073e9SAndroid Build Coastguard Worker
10844*da0073e9SAndroid Build Coastguard Worker            # Check that forward gradients are not impacted by no_grad
10845*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
10846*da0073e9SAndroid Build Coastguard Worker                out = dual * 3
10847*da0073e9SAndroid Build Coastguard Worker            p, t = fwAD.unpack_dual(out)
10848*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.requires_grad)
10849*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(t.requires_grad)
10850*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p, foo * 3)
10851*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t, bar * 3)
10852*da0073e9SAndroid Build Coastguard Worker
10853*da0073e9SAndroid Build Coastguard Worker            # Check that forward gradients are not impacted by inplace detach
10854*da0073e9SAndroid Build Coastguard Worker            dual = dual.clone()
10855*da0073e9SAndroid Build Coastguard Worker            dual.detach_()
10856*da0073e9SAndroid Build Coastguard Worker            out = dual * 2
10857*da0073e9SAndroid Build Coastguard Worker            p, t = fwAD.unpack_dual(out)
10858*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(p.requires_grad)
10859*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(p, foo * 2)
10860*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(t)
10861*da0073e9SAndroid Build Coastguard Worker
10862*da0073e9SAndroid Build Coastguard Worker    def test_view_inplace_non_differentiable_views(self):
10863*da0073e9SAndroid Build Coastguard Worker        original_foo = torch.rand(2, dtype=torch.double)
10864*da0073e9SAndroid Build Coastguard Worker        original_bar = torch.ones(2, dtype=torch.double)
10865*da0073e9SAndroid Build Coastguard Worker
10866*da0073e9SAndroid Build Coastguard Worker        # Do clones to be able to compare the values updated inplace
10867*da0073e9SAndroid Build Coastguard Worker        # with the original content of these Tensors
10868*da0073e9SAndroid Build Coastguard Worker        foo = original_foo.clone()
10869*da0073e9SAndroid Build Coastguard Worker        bar = original_bar.clone()
10870*da0073e9SAndroid Build Coastguard Worker
10871*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10872*da0073e9SAndroid Build Coastguard Worker            # Note that in this test, we use "update" to mean computing the right tangent for the dual
10873*da0073e9SAndroid Build Coastguard Worker            # All the inplace operations here are expected to update the primal value of the Tensors but
10874*da0073e9SAndroid Build Coastguard Worker            # not always their tangents.
10875*da0073e9SAndroid Build Coastguard Worker            # Also, all mentions of "non differentiable view" here mean a non-forward-differentiable view
10876*da0073e9SAndroid Build Coastguard Worker            # unless specified otherwise.
10877*da0073e9SAndroid Build Coastguard Worker            # See note [Forward Grad View/inplace] for more details on how these views work.
10878*da0073e9SAndroid Build Coastguard Worker
10879*da0073e9SAndroid Build Coastguard Worker            # Check that inplace ops do not update non-differentiable views
10880*da0073e9SAndroid Build Coastguard Worker            # Non differentiable view
10881*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, bar)
10882*da0073e9SAndroid Build Coastguard Worker            dual *= 2
10883*da0073e9SAndroid Build Coastguard Worker            # Check that non differentiable view's tangent was not updated
10884*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10885*da0073e9SAndroid Build Coastguard Worker            # Check that the computed result is correct
10886*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(bar, original_bar * 2)
10887*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
10888*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(foo, original_foo * 2)
10889*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 2)
10890*da0073e9SAndroid Build Coastguard Worker            # Other non differentiable view
10891*da0073e9SAndroid Build Coastguard Worker            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
10892*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(dual_primal)[1])
10893*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(dual_tangent)[1])
10894*da0073e9SAndroid Build Coastguard Worker            dual_primal *= 2
10895*da0073e9SAndroid Build Coastguard Worker            # Ensure dual's tangent did not change
10896*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
10897*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
10898*da0073e9SAndroid Build Coastguard Worker            dual_tangent *= 2
10899*da0073e9SAndroid Build Coastguard Worker            # Ensure dual's primal did not change
10900*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
10901*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 4)
10902*da0073e9SAndroid Build Coastguard Worker
10903*da0073e9SAndroid Build Coastguard Worker    def test_view_inplace_differentiable_views(self):
10904*da0073e9SAndroid Build Coastguard Worker        original_foo = torch.rand(2)
10905*da0073e9SAndroid Build Coastguard Worker        original_bar = torch.ones(2)
10906*da0073e9SAndroid Build Coastguard Worker
10907*da0073e9SAndroid Build Coastguard Worker        # Do clones to be able to compare the values updated inplace
10908*da0073e9SAndroid Build Coastguard Worker        # with the original content of these Tensors
10909*da0073e9SAndroid Build Coastguard Worker        foo = original_foo.clone()
10910*da0073e9SAndroid Build Coastguard Worker        bar = original_bar.clone()
10911*da0073e9SAndroid Build Coastguard Worker
10912*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10913*da0073e9SAndroid Build Coastguard Worker            # Check that inplace ops do update differentiable views but stop at non-differentiable ones
10914*da0073e9SAndroid Build Coastguard Worker            # A non differentiable view
10915*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, bar)
10916*da0073e9SAndroid Build Coastguard Worker            # A differentiable view
10917*da0073e9SAndroid Build Coastguard Worker            view = dual.narrow(0, 0, 1)
10918*da0073e9SAndroid Build Coastguard Worker            view *= 2
10919*da0073e9SAndroid Build Coastguard Worker            # Check that non differentiable view was not updated
10920*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10921*da0073e9SAndroid Build Coastguard Worker            # Check that differentiable view was updated
10922*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(dual)[1], torch.tensor([2.0, 1.0]))
10923*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(view)[1], torch.tensor([2.0]))
10924*da0073e9SAndroid Build Coastguard Worker
10925*da0073e9SAndroid Build Coastguard Worker            # Check that we track differentiable view even for Tensors that are not dual
10926*da0073e9SAndroid Build Coastguard Worker            baz = torch.rand(2)
10927*da0073e9SAndroid Build Coastguard Worker            baz += dual
10928*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(baz)[1], fwAD.unpack_dual(dual)[1])
10929*da0073e9SAndroid Build Coastguard Worker            # Updates made through a view should be tracked as well
10930*da0073e9SAndroid Build Coastguard Worker            baz = torch.rand(2)
10931*da0073e9SAndroid Build Coastguard Worker            baz[0] = dual[0]
10932*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(baz)[1][0], fwAD.unpack_dual(dual)[1][0])
10933*da0073e9SAndroid Build Coastguard Worker            # Unused values get a gradient of 0
10934*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(fwAD.unpack_dual(baz)[1][1], 0.0)
10935*da0073e9SAndroid Build Coastguard Worker
10936*da0073e9SAndroid Build Coastguard Worker            # Check that forward non-differentiable views do prevent gradient update
10937*da0073e9SAndroid Build Coastguard Worker            baz = torch.rand(2)
10938*da0073e9SAndroid Build Coastguard Worker            view = baz.detach()
10939*da0073e9SAndroid Build Coastguard Worker            view += dual
10940*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(baz)[1])
10941*da0073e9SAndroid Build Coastguard Worker
10942*da0073e9SAndroid Build Coastguard Worker    def test_view_inplace_always_creates_a_view(self):
10943*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/67800
10944*da0073e9SAndroid Build Coastguard Worker        # The codepath may depend on the op. At the time of writing, when self is not a dual tensor
10945*da0073e9SAndroid Build Coastguard Worker        # the resulting forward grad for self for...
10946*da0073e9SAndroid Build Coastguard Worker        # - add_ has the same layout as self
10947*da0073e9SAndroid Build Coastguard Worker        # - mul_ has the same layout as other
10948*da0073e9SAndroid Build Coastguard Worker        # This is kind of fragile because the above depends on how the forward grad expression
10949*da0073e9SAndroid Build Coastguard Worker        # is written. For add and mul at least, the output inherits the layout of the LHS.
10950*da0073e9SAndroid Build Coastguard Worker        # We want to handle at least these two cases.
10951*da0073e9SAndroid Build Coastguard Worker        inplace_binary_ops = (  # Add more to this list?
10952*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x.add_(y),
10953*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x.mul_(y),
10954*da0073e9SAndroid Build Coastguard Worker            lambda x, y: x.copy_(y),
10955*da0073e9SAndroid Build Coastguard Worker        )
10956*da0073e9SAndroid Build Coastguard Worker
10957*da0073e9SAndroid Build Coastguard Worker        for inplace_binary_op in inplace_binary_ops:
10958*da0073e9SAndroid Build Coastguard Worker            base = torch.randn(2, 2)
10959*da0073e9SAndroid Build Coastguard Worker            view = base.transpose(0, 1)
10960*da0073e9SAndroid Build Coastguard Worker
10961*da0073e9SAndroid Build Coastguard Worker            primal = torch.randn(2, 2)
10962*da0073e9SAndroid Build Coastguard Worker            tangent = torch.randn(2, 2)
10963*da0073e9SAndroid Build Coastguard Worker
10964*da0073e9SAndroid Build Coastguard Worker            with fwAD.dual_level():
10965*da0073e9SAndroid Build Coastguard Worker                dual = fwAD.make_dual(primal, tangent)
10966*da0073e9SAndroid Build Coastguard Worker                inplace_binary_op(view, dual)
10967*da0073e9SAndroid Build Coastguard Worker
10968*da0073e9SAndroid Build Coastguard Worker                # Verify that a view relationship is created for both the primal and tangent
10969*da0073e9SAndroid Build Coastguard Worker                p, t = fwAD.unpack_dual(base)
10970*da0073e9SAndroid Build Coastguard Worker                p_clone = p.clone()
10971*da0073e9SAndroid Build Coastguard Worker                t_clone = t.clone()
10972*da0073e9SAndroid Build Coastguard Worker                view *= 2
10973*da0073e9SAndroid Build Coastguard Worker                p, t = fwAD.unpack_dual(base)
10974*da0073e9SAndroid Build Coastguard Worker
10975*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(p_clone * 2, p))
10976*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(t_clone * 2, t))
10977*da0073e9SAndroid Build Coastguard Worker
10978*da0073e9SAndroid Build Coastguard Worker    def test_grad_cleanup(self):
10979*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
10980*da0073e9SAndroid Build Coastguard Worker        bar = torch.rand(2)
10981*da0073e9SAndroid Build Coastguard Worker        baz = torch.rand(2)
10982*da0073e9SAndroid Build Coastguard Worker
10983*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10984*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(foo, bar)
10985*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10986*da0073e9SAndroid Build Coastguard Worker            self.assertIs(fwAD.unpack_dual(dual)[1], bar)
10987*da0073e9SAndroid Build Coastguard Worker
10988*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(fwAD.unpack_dual(dual)[1])
10989*da0073e9SAndroid Build Coastguard Worker
10990*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
10991*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(fwAD.unpack_dual(foo)[1])
10992*da0073e9SAndroid Build Coastguard Worker            new_dual = fwAD.make_dual(foo, baz)
10993*da0073e9SAndroid Build Coastguard Worker
10994*da0073e9SAndroid Build Coastguard Worker            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
10995*da0073e9SAndroid Build Coastguard Worker            new_dual_primal, new_dual_tangent = fwAD.unpack_dual(new_dual)
10996*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dual_primal, new_dual_primal)
10997*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(dual_tangent)
10998*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(new_dual_tangent, baz)
10999*da0073e9SAndroid Build Coastguard Worker
11000*da0073e9SAndroid Build Coastguard Worker    def test_detach_view_tracking(self):
11001*da0073e9SAndroid Build Coastguard Worker        # Default detach is both forward and backward non-differentiable
11002*da0073e9SAndroid Build Coastguard Worker        foo = torch.rand(2)
11003*da0073e9SAndroid Build Coastguard Worker        foo_weak = torch._C._WeakTensorRef(foo)
11004*da0073e9SAndroid Build Coastguard Worker
11005*da0073e9SAndroid Build Coastguard Worker        out = foo.detach()
11006*da0073e9SAndroid Build Coastguard Worker
11007*da0073e9SAndroid Build Coastguard Worker        del foo
11008*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(foo_weak.expired())
11009*da0073e9SAndroid Build Coastguard Worker
11010*da0073e9SAndroid Build Coastguard Worker    def test_out_variant(self):
11011*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
11012*da0073e9SAndroid Build Coastguard Worker            foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
11013*da0073e9SAndroid Build Coastguard Worker            bar = torch.rand(2)
11014*da0073e9SAndroid Build Coastguard Worker
11015*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "out= function"):
11016*da0073e9SAndroid Build Coastguard Worker                torch.add(bar, bar, out=foo)
11017*da0073e9SAndroid Build Coastguard Worker
11018*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "out= function"):
11019*da0073e9SAndroid Build Coastguard Worker                torch.add(foo, bar, out=bar)
11020*da0073e9SAndroid Build Coastguard Worker
11021*da0073e9SAndroid Build Coastguard Worker    def test_non_differentiable(self):
11022*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
11023*da0073e9SAndroid Build Coastguard Worker            foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
11024*da0073e9SAndroid Build Coastguard Worker            bar = torch.rand(2)
11025*da0073e9SAndroid Build Coastguard Worker
11026*da0073e9SAndroid Build Coastguard Worker            # No differentiable outputs, shouldn't error
11027*da0073e9SAndroid Build Coastguard Worker            eq = foo == bar
11028*da0073e9SAndroid Build Coastguard Worker
11029*da0073e9SAndroid Build Coastguard Worker            # Inplace
11030*da0073e9SAndroid Build Coastguard Worker            foo.eq_(bar)
11031*da0073e9SAndroid Build Coastguard Worker
11032*da0073e9SAndroid Build Coastguard Worker    def test_create_new_zeros_with_same_meta(self):
11033*da0073e9SAndroid Build Coastguard Worker        new_zeroes_fn = torch.ops.aten._new_zeros_with_same_feature_meta
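
        # Illustrative sketch (added for clarity; it mirrors the generic checks in
        # `check` below): treating the leading dim of `t_example` as a batch dim,
        # the op returns a fresh zero tensor whose trailing dims copy the feature
        # metadata (sizes/strides/dtype) of `target_example`.
        t_example = torch.randn(3, 2)
        target_example = torch.randn(4, 5)
        result_example = new_zeroes_fn(t_example, target_example, self_num_batch_dims=1)
        self.assertEqual(result_example.dim(), target_example.dim() + 1)
        self.assertEqual(result_example.shape[1:], target_example.shape)
        self.assertEqual(result_example.dtype, target_example.dtype)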
11034*da0073e9SAndroid Build Coastguard Worker
11035*da0073e9SAndroid Build Coastguard Worker        def check(a, b):
11036*da0073e9SAndroid Build Coastguard Worker            def assert_same_meta(t, target):
11037*da0073e9SAndroid Build Coastguard Worker                for num_bdim in range(t.dim()):
11038*da0073e9SAndroid Build Coastguard Worker                    result = new_zeroes_fn(t, target, self_num_batch_dims=num_bdim)
11039*da0073e9SAndroid Build Coastguard Worker
11040*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result.dim(), target.dim() + num_bdim)
11041*da0073e9SAndroid Build Coastguard Worker
11042*da0073e9SAndroid Build Coastguard Worker                    # Check size/strides match for feature dims only
11043*da0073e9SAndroid Build Coastguard Worker                    for i in range(num_bdim, result.dim()):
11044*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(result.size()[i], target.size()[i - num_bdim])
11045*da0073e9SAndroid Build Coastguard Worker                        self.assertEqual(
11046*da0073e9SAndroid Build Coastguard Worker                            result.stride()[i], target.stride()[i - num_bdim]
11047*da0073e9SAndroid Build Coastguard Worker                        )
11048*da0073e9SAndroid Build Coastguard Worker
11049*da0073e9SAndroid Build Coastguard Worker                    # Check that we generate strides reasonably
11050*da0073e9SAndroid Build Coastguard Worker                    if target.is_contiguous():
11051*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(result.is_contiguous())
11052*da0073e9SAndroid Build Coastguard Worker
11053*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result.storage_offset(), target.storage_offset())
11054*da0073e9SAndroid Build Coastguard Worker
11055*da0073e9SAndroid Build Coastguard Worker                    prod_of_t_bdims = reduce(operator.mul, t.size()[:num_bdim], 1)
11056*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
11057*da0073e9SAndroid Build Coastguard Worker                        len(result.storage()), len(target.storage()) * prod_of_t_bdims
11058*da0073e9SAndroid Build Coastguard Worker                    )
11059*da0073e9SAndroid Build Coastguard Worker
11060*da0073e9SAndroid Build Coastguard Worker                    # Check that the TensorOptions are the same
11061*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(result.dtype, target.dtype)
11062*da0073e9SAndroid Build Coastguard Worker
11063*da0073e9SAndroid Build Coastguard Worker            assert_same_meta(a, b)
11064*da0073e9SAndroid Build Coastguard Worker            assert_same_meta(b, a)
11065*da0073e9SAndroid Build Coastguard Worker
11066*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5, dtype=torch.float)
11067*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 3, 4, dtype=torch.double)
11068*da0073e9SAndroid Build Coastguard Worker        check(a, b)
11069*da0073e9SAndroid Build Coastguard Worker
11070*da0073e9SAndroid Build Coastguard Worker        # non-contiguous case
11071*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(2, 3, 4).transpose(0, 1).contiguous().transpose(0, 1)
11072*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 3, 4)
11073*da0073e9SAndroid Build Coastguard Worker        check(a, b)
11074*da0073e9SAndroid Build Coastguard Worker
11075*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5).narrow(0, 1, 2)
11076*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2)
11077*da0073e9SAndroid Build Coastguard Worker        check(a, b)
11078*da0073e9SAndroid Build Coastguard Worker
11079*da0073e9SAndroid Build Coastguard Worker        # tensor is not a view, but still does not index the entirety of its storage
11080*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(5).resize_(4)
11081*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(4)
11082*da0073e9SAndroid Build Coastguard Worker        check(a, b)
11083*da0073e9SAndroid Build Coastguard Worker
11084*da0073e9SAndroid Build Coastguard Worker        # Zero-numel tensors
11085*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(1, 0, 2)
11086*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(1, 2)
11087*da0073e9SAndroid Build Coastguard Worker        check(a, b)
11088*da0073e9SAndroid Build Coastguard Worker
11089*da0073e9SAndroid Build Coastguard Worker        # Scalar tensor
11090*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0)
11091*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(1, 2)
11092*da0073e9SAndroid Build Coastguard Worker        check(a, b)
11093*da0073e9SAndroid Build Coastguard Worker
11094*da0073e9SAndroid Build Coastguard Worker    def test_backward_graph_destruction(self):
11095*da0073e9SAndroid Build Coastguard Worker        def fn():
11096*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(10, requires_grad=True)
11097*da0073e9SAndroid Build Coastguard Worker
11098*da0073e9SAndroid Build Coastguard Worker            da = fwAD.make_dual(torch.rand_like(a), a)
11099*da0073e9SAndroid Build Coastguard Worker
11100*da0073e9SAndroid Build Coastguard Worker            # Create an object with a C++ reference cycle:
11101*da0073e9SAndroid Build Coastguard Worker            # db -> AutogradMeta -> ForwardGrad -> db's grad
11102*da0073e9SAndroid Build Coastguard Worker            # db's grad -> AutogradMeta -> MulBackward
11103*da0073e9SAndroid Build Coastguard Worker            # MulBackward -> SavedVariable -> db
11104*da0073e9SAndroid Build Coastguard Worker            db = da.exp()
11105*da0073e9SAndroid Build Coastguard Worker
11106*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
11107*da0073e9SAndroid Build Coastguard Worker            fn()
11108*da0073e9SAndroid Build Coastguard Worker        # This test makes sure that we don't deadlock on exit of this
11109*da0073e9SAndroid Build Coastguard Worker        # context manager. If it does deadlock, there is most likely something
11110*da0073e9SAndroid Build Coastguard Worker        # wrong with the locking of the forward AD level.
11111*da0073e9SAndroid Build Coastguard Worker
11112*da0073e9SAndroid Build Coastguard Worker
11113*da0073e9SAndroid Build Coastguard Worker# Generic device type autograd tests.
11114*da0073e9SAndroid Build Coastguard Workerclass TestAutogradDeviceType(TestCase):
11115*da0073e9SAndroid Build Coastguard Worker    def test_min_max_median_backprops_to_all_values(self, device):
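        # When several entries tie for the min/max/median, the gradient is expected
        # to be split evenly among them; e.g. with three tied values each should
        # receive a gradient of 1/3 (this is what the assertions below check).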
11116*da0073e9SAndroid Build Coastguard Worker        for f in [torch.min, torch.max, torch.median, torch.nanmedian]:
11117*da0073e9SAndroid Build Coastguard Worker            x1 = torch.tensor(
11118*da0073e9SAndroid Build Coastguard Worker                [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True
11119*da0073e9SAndroid Build Coastguard Worker            )
11120*da0073e9SAndroid Build Coastguard Worker            x2 = torch.tensor(
11121*da0073e9SAndroid Build Coastguard Worker                [float("nan"), float("nan"), float("nan")], requires_grad=True
11122*da0073e9SAndroid Build Coastguard Worker            )
11123*da0073e9SAndroid Build Coastguard Worker            for x in [x1, x2]:
11124*da0073e9SAndroid Build Coastguard Worker                y = f(x)
11125*da0073e9SAndroid Build Coastguard Worker                y.backward()
11126*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(x.grad.sum(), 1.0)
11127*da0073e9SAndroid Build Coastguard Worker                self.assertEqual((x.grad == 1 / 3).sum(), 3)
11128*da0073e9SAndroid Build Coastguard Worker
11129*da0073e9SAndroid Build Coastguard Worker    def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device):
11130*da0073e9SAndroid Build Coastguard Worker        # tests that gradients are evenly distributed when there are multiple max/min values
11131*da0073e9SAndroid Build Coastguard Worker        # tested here instead of adding a SampleInput because the backward for this case is non-differentiable for gradgrad,
11132*da0073e9SAndroid Build Coastguard Worker        # as is the case for test_min_max_median_backprops_to_all_values above
11133*da0073e9SAndroid Build Coastguard Worker        fns = (torch.scatter_reduce, torch.index_reduce)
11134*da0073e9SAndroid Build Coastguard Worker        reduces = ("amin", "amax")
11135*da0073e9SAndroid Build Coastguard Worker        for fn, reduction in product(fns, reduces):
11136*da0073e9SAndroid Build Coastguard Worker            input = torch.randn(
11137*da0073e9SAndroid Build Coastguard Worker                (2, 3), device=device, dtype=torch.float64, requires_grad=True
11138*da0073e9SAndroid Build Coastguard Worker            )
11139*da0073e9SAndroid Build Coastguard Worker            src = input.clone().detach_().requires_grad_(True)
11140*da0073e9SAndroid Build Coastguard Worker            idx = torch.arange(2).to(dtype=torch.long, device=device)
11141*da0073e9SAndroid Build Coastguard Worker            if fn == torch.scatter_reduce:
11142*da0073e9SAndroid Build Coastguard Worker                idx = idx.unsqueeze(-1).expand((2, 3))
11143*da0073e9SAndroid Build Coastguard Worker
11144*da0073e9SAndroid Build Coastguard Worker            gradcheck(fn, (input, 0, idx, src, reduction), check_batched_grad=False)
11145*da0073e9SAndroid Build Coastguard Worker
11146*da0073e9SAndroid Build Coastguard Worker    def test_scatter_index_reduce_prod_gradgrad_error(self, device):
11147*da0073e9SAndroid Build Coastguard Worker        # test that double backward raises an error for the case where 2 zeros in src
11148*da0073e9SAndroid Build Coastguard Worker        # are scattered to the same position in self
11149*da0073e9SAndroid Build Coastguard Worker        input = torch.tensor(
11150*da0073e9SAndroid Build Coastguard Worker            [1.0], device=device, dtype=torch.float64, requires_grad=True
11151*da0073e9SAndroid Build Coastguard Worker        )
11152*da0073e9SAndroid Build Coastguard Worker        src = torch.tensor(
11153*da0073e9SAndroid Build Coastguard Worker            [0.0, 0.0], device=device, dtype=torch.float64, requires_grad=True
11154*da0073e9SAndroid Build Coastguard Worker        )
11155*da0073e9SAndroid Build Coastguard Worker        idx = torch.tensor([0, 0], device=device, dtype=torch.long)
11156*da0073e9SAndroid Build Coastguard Worker
11157*da0073e9SAndroid Build Coastguard Worker        for fn in (torch.scatter_reduce, torch.index_reduce):
11158*da0073e9SAndroid Build Coastguard Worker            # check that this case passes gradcheck
11159*da0073e9SAndroid Build Coastguard Worker            gradcheck(fn, (input, 0, idx, src, "prod"), check_batched_grad=False)
11160*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(
11161*da0073e9SAndroid Build Coastguard Worker                RuntimeError, "Double backward is unsupported for"
11162*da0073e9SAndroid Build Coastguard Worker            ):
11163*da0073e9SAndroid Build Coastguard Worker                gradgradcheck(fn, (input, 0, idx, src, "prod"))
11164*da0073e9SAndroid Build Coastguard Worker
11165*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11166*da0073e9SAndroid Build Coastguard Worker    def test_parameter_resize(self, device):
11167*da0073e9SAndroid Build Coastguard Worker        asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device))
11168*da0073e9SAndroid Build Coastguard Worker
11169*da0073e9SAndroid Build Coastguard Worker        for i in range(2):
11170*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
11171*da0073e9SAndroid Build Coastguard Worker                asd.set_(asd[1:])
11172*da0073e9SAndroid Build Coastguard Worker                asd.grad = None
11173*da0073e9SAndroid Build Coastguard Worker
11174*da0073e9SAndroid Build Coastguard Worker            m = torch.cat((asd, asd))
11175*da0073e9SAndroid Build Coastguard Worker            m.sum().backward()
11176*da0073e9SAndroid Build Coastguard Worker
11177*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11178*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
11179*da0073e9SAndroid Build Coastguard Worker    def test_sparse_ctor_getter_backward(self, device, dtype):
11180*da0073e9SAndroid Build Coastguard Worker        # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test
11181*da0073e9SAndroid Build Coastguard Worker        def _test(size, sparse_dim, nnz, device):
11182*da0073e9SAndroid Build Coastguard Worker            v_size = [nnz] + list(size[sparse_dim:])
11183*da0073e9SAndroid Build Coastguard Worker            i = torch.rand(sparse_dim, nnz)
11184*da0073e9SAndroid Build Coastguard Worker            i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
11185*da0073e9SAndroid Build Coastguard Worker            i = i.to(torch.long)
11186*da0073e9SAndroid Build Coastguard Worker
11187*da0073e9SAndroid Build Coastguard Worker            inp = torch.randn(
11188*da0073e9SAndroid Build Coastguard Worker                v_size, dtype=torch.double, device=device, requires_grad=True
11189*da0073e9SAndroid Build Coastguard Worker            )
11190*da0073e9SAndroid Build Coastguard Worker            other = self.genSparseTensor(
11191*da0073e9SAndroid Build Coastguard Worker                size, sparse_dim, nnz, is_uncoalesced=True, device=device, dtype=dtype
11192*da0073e9SAndroid Build Coastguard Worker            )[0]
11193*da0073e9SAndroid Build Coastguard Worker
11194*da0073e9SAndroid Build Coastguard Worker            def fn(v):
11195*da0073e9SAndroid Build Coastguard Worker                x = torch.sparse_coo_tensor(i, v, size, dtype=dtype, device=device)
11196*da0073e9SAndroid Build Coastguard Worker                y = (x + other).coalesce()
11197*da0073e9SAndroid Build Coastguard Worker                yv = y.values()
11198*da0073e9SAndroid Build Coastguard Worker                new_v = yv.tanh()
11199*da0073e9SAndroid Build Coastguard Worker                z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
11200*da0073e9SAndroid Build Coastguard Worker                return z.coalesce().values()
11201*da0073e9SAndroid Build Coastguard Worker
11202*da0073e9SAndroid Build Coastguard Worker            gradcheck(fn, (inp,), check_batched_grad=False)
11203*da0073e9SAndroid Build Coastguard Worker            # FIXME: make gradgradcheck work.
11204*da0073e9SAndroid Build Coastguard Worker            # gradgradcheck(fn, (inp,), check_batched_grad=False)
11205*da0073e9SAndroid Build Coastguard Worker
11206*da0073e9SAndroid Build Coastguard Worker            # assert that _values is non-differentiable
11207*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
11208*da0073e9SAndroid Build Coastguard Worker                other.detach().requires_grad_()._values().backward(
11209*da0073e9SAndroid Build Coastguard Worker                    torch.ones_like(other._values())
11210*da0073e9SAndroid Build Coastguard Worker                )
11211*da0073e9SAndroid Build Coastguard Worker
11212*da0073e9SAndroid Build Coastguard Worker        for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
11213*da0073e9SAndroid Build Coastguard Worker            sparse_size = [] if empty_i else [2, 1]
11214*da0073e9SAndroid Build Coastguard Worker            dense_size = [1, 0, 2] if empty_v else [1, 2]
11215*da0073e9SAndroid Build Coastguard Worker            nnz = 0 if empty_nnz else 5
11216*da0073e9SAndroid Build Coastguard Worker            _test(sparse_size + dense_size, len(sparse_size), nnz, device)
11217*da0073e9SAndroid Build Coastguard Worker
11218*da0073e9SAndroid Build Coastguard Worker    @skipMeta
11219*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
11220*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.double, torch.cdouble)
11221*da0073e9SAndroid Build Coastguard Worker    def test_sparse_backward(self, device, dtype):
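        # Helper autograd Function (documented here for clarity): forward returns
        # `x` unchanged and stashes `grad_x`; backward ignores the incoming gradient
        # and returns the stashed `grad_x` instead, letting the test control exactly
        # which sparse/dense gradients get accumulated into x.grad.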
11222*da0073e9SAndroid Build Coastguard Worker        class FixedGradientFunction(Function):
11223*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11224*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, grad_x):
11225*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(grad_x)
11226*da0073e9SAndroid Build Coastguard Worker                return x
11227*da0073e9SAndroid Build Coastguard Worker
11228*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11229*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_x):
11230*da0073e9SAndroid Build Coastguard Worker                (saved_grad_x,) = ctx.saved_tensors
11231*da0073e9SAndroid Build Coastguard Worker                return saved_grad_x, None
11232*da0073e9SAndroid Build Coastguard Worker
11233*da0073e9SAndroid Build Coastguard Worker        size = torch.Size([6, 3, 2])
11234*da0073e9SAndroid Build Coastguard Worker        i1 = torch.tensor([[0, 3, 4], [0, 2, 2]], dtype=torch.long)
11235*da0073e9SAndroid Build Coastguard Worker        v1 = make_tensor([3, 2], dtype=dtype, device=device)
11236*da0073e9SAndroid Build Coastguard Worker        sparse_grad1 = torch.sparse_coo_tensor(i1, v1, size, dtype=dtype, device=device)
11237*da0073e9SAndroid Build Coastguard Worker        i2 = torch.tensor([[0, 1, 3, 4], [0, 1, 2, 2]], dtype=torch.long)
11238*da0073e9SAndroid Build Coastguard Worker        v2 = make_tensor([4, 2], dtype=dtype, device=device)
11239*da0073e9SAndroid Build Coastguard Worker        sparse_grad2 = torch.sparse_coo_tensor(i2, v2, size, dtype=dtype, device=device)
11240*da0073e9SAndroid Build Coastguard Worker        dense_grad = torch.rand(size, device=device, dtype=dtype)
11241*da0073e9SAndroid Build Coastguard Worker        fn = FixedGradientFunction
11242*da0073e9SAndroid Build Coastguard Worker
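        # The three cases below check that dense and sparse gradients accumulate into
        # x.grad correctly regardless of the order in which they are produced.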
11243*da0073e9SAndroid Build Coastguard Worker        # sparse first
11244*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
11245*da0073e9SAndroid Build Coastguard Worker        (
11246*da0073e9SAndroid Build Coastguard Worker            fn.apply(x, sparse_grad1)
11247*da0073e9SAndroid Build Coastguard Worker            + fn.apply(x, dense_grad)
11248*da0073e9SAndroid Build Coastguard Worker            + fn.apply(x, sparse_grad2)
11249*da0073e9SAndroid Build Coastguard Worker        ).sum().abs().backward()
11250*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
11251*da0073e9SAndroid Build Coastguard Worker        # dense first
11252*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
11253*da0073e9SAndroid Build Coastguard Worker        (
11254*da0073e9SAndroid Build Coastguard Worker            fn.apply(x, dense_grad)
11255*da0073e9SAndroid Build Coastguard Worker            + fn.apply(x, sparse_grad1)
11256*da0073e9SAndroid Build Coastguard Worker            + fn.apply(x, sparse_grad2)
11257*da0073e9SAndroid Build Coastguard Worker        ).sum().abs().backward()
11258*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
11259*da0073e9SAndroid Build Coastguard Worker        # sparse only
11260*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
11261*da0073e9SAndroid Build Coastguard Worker        (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
11262*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)
11263*da0073e9SAndroid Build Coastguard Worker
11264*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
11265*da0073e9SAndroid Build Coastguard Worker    def test_sparse_mask_autograd(self, device):
11266*da0073e9SAndroid Build Coastguard Worker        tensor = torch.randn(3, requires_grad=True, device=device)
11267*da0073e9SAndroid Build Coastguard Worker        mask = torch.ones(3, device=device)
11268*da0073e9SAndroid Build Coastguard Worker        mask[1] = 0
11269*da0073e9SAndroid Build Coastguard Worker        mask = mask.to_sparse()
11270*da0073e9SAndroid Build Coastguard Worker        converted = tensor.sparse_mask(mask).to_dense()
11271*da0073e9SAndroid Build Coastguard Worker        converted.sum().backward()
11272*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(tensor.grad, mask.to_dense())
11273*da0073e9SAndroid Build Coastguard Worker
11274*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11275*da0073e9SAndroid Build Coastguard Worker    def test_pyscalar_conversions(self, device):
11276*da0073e9SAndroid Build Coastguard Worker        def _test_pyscalar_conversions(t, integral_conv):
11277*da0073e9SAndroid Build Coastguard Worker            # integral -> integral
11278*da0073e9SAndroid Build Coastguard Worker            l = t(torch.zeros(1, 1, 1, dtype=torch.long))
11279*da0073e9SAndroid Build Coastguard Worker            pyscalar = -12345
11280*da0073e9SAndroid Build Coastguard Worker            l[0] = pyscalar
11281*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(integral_conv(l), pyscalar)
11282*da0073e9SAndroid Build Coastguard Worker
11283*da0073e9SAndroid Build Coastguard Worker            # floating point -> floating point
11284*da0073e9SAndroid Build Coastguard Worker            f = Variable(t(torch.randn(1, 1, dtype=torch.double)))
11285*da0073e9SAndroid Build Coastguard Worker            pyscalar = -12345.1
11286*da0073e9SAndroid Build Coastguard Worker            f[0] = pyscalar
11287*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(float(f), pyscalar)
11288*da0073e9SAndroid Build Coastguard Worker            f[0] = nan
11289*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(math.isnan(float(f)))
11290*da0073e9SAndroid Build Coastguard Worker            f[0] = inf
11291*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(float(f), inf)
11292*da0073e9SAndroid Build Coastguard Worker            f[0] = -inf
11293*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(float(f), -inf)
11294*da0073e9SAndroid Build Coastguard Worker
11295*da0073e9SAndroid Build Coastguard Worker            # integral -> floating point
11296*da0073e9SAndroid Build Coastguard Worker            # check we can convert something that loses precision
11297*da0073e9SAndroid Build Coastguard Worker            pyscalar = 1234567890123456789
11298*da0073e9SAndroid Build Coastguard Worker            self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
11299*da0073e9SAndroid Build Coastguard Worker            l[0] = pyscalar
11300*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(float(l), float(pyscalar))
11301*da0073e9SAndroid Build Coastguard Worker
11302*da0073e9SAndroid Build Coastguard Worker            # floating point -> integral
11303*da0073e9SAndroid Build Coastguard Worker            f[0] = nan
11304*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(ValueError, lambda: integral_conv(f[0]))
11305*da0073e9SAndroid Build Coastguard Worker            f[0] = inf
11306*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
11307*da0073e9SAndroid Build Coastguard Worker            f[0] = -inf
11308*da0073e9SAndroid Build Coastguard Worker            self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
11309*da0073e9SAndroid Build Coastguard Worker            f[0] = sys.float_info.max
11310*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(integral_conv(f), sys.float_info.max)
11311*da0073e9SAndroid Build Coastguard Worker
11312*da0073e9SAndroid Build Coastguard Worker            # bool, nonzero
11313*da0073e9SAndroid Build Coastguard Worker            def test_nonzero(tensor, value, expected):
11314*da0073e9SAndroid Build Coastguard Worker                tensor[0] = value
11315*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, bool(tensor))
11316*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(expected, True if tensor else False)
11317*da0073e9SAndroid Build Coastguard Worker
11318*da0073e9SAndroid Build Coastguard Worker            test_nonzero(l, 0, False)
11319*da0073e9SAndroid Build Coastguard Worker            test_nonzero(l, -2, True)
11320*da0073e9SAndroid Build Coastguard Worker            test_nonzero(f, 0.0, False)
11321*da0073e9SAndroid Build Coastguard Worker            test_nonzero(f, sys.float_info.min, True)
11322*da0073e9SAndroid Build Coastguard Worker            test_nonzero(f, nan, bool(nan))
11323*da0073e9SAndroid Build Coastguard Worker            test_nonzero(f, inf, bool(inf))
11324*da0073e9SAndroid Build Coastguard Worker            test_nonzero(f, -inf, bool(-inf))
11325*da0073e9SAndroid Build Coastguard Worker
11326*da0073e9SAndroid Build Coastguard Worker        _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x))
11327*da0073e9SAndroid Build Coastguard Worker
11328*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float32)
11329*da0073e9SAndroid Build Coastguard Worker    @dtypesIfCUDA(
11330*da0073e9SAndroid Build Coastguard Worker        torch.half,
11331*da0073e9SAndroid Build Coastguard Worker        torch.float,
11332*da0073e9SAndroid Build Coastguard Worker        torch.double,
11333*da0073e9SAndroid Build Coastguard Worker        torch.int8,
11334*da0073e9SAndroid Build Coastguard Worker        torch.int16,
11335*da0073e9SAndroid Build Coastguard Worker        torch.int32,
11336*da0073e9SAndroid Build Coastguard Worker        torch.int64,
11337*da0073e9SAndroid Build Coastguard Worker    )
11338*da0073e9SAndroid Build Coastguard Worker    @dtypes(
11339*da0073e9SAndroid Build Coastguard Worker        torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64
11340*da0073e9SAndroid Build Coastguard Worker    )
11341*da0073e9SAndroid Build Coastguard Worker    def test_set_requires_grad_only_for_floats(self, device, dtype):
11342*da0073e9SAndroid Build Coastguard Worker        def f1():
11343*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1, dtype=dtype, device=device)
11344*da0073e9SAndroid Build Coastguard Worker            a.requires_grad_()
11345*da0073e9SAndroid Build Coastguard Worker
11346*da0073e9SAndroid Build Coastguard Worker        def f2():
11347*da0073e9SAndroid Build Coastguard Worker            a = torch.ones(1, dtype=dtype, device=device)
11348*da0073e9SAndroid Build Coastguard Worker            a.requires_grad = True
11349*da0073e9SAndroid Build Coastguard Worker
11350*da0073e9SAndroid Build Coastguard Worker        def f3():
11351*da0073e9SAndroid Build Coastguard Worker            torch.ones(1, dtype=dtype, device=device, requires_grad=True)
11352*da0073e9SAndroid Build Coastguard Worker
11353*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(1, dtype=dtype, device=device)
11354*da0073e9SAndroid Build Coastguard Worker        a.requires_grad = False  # should always work
11355*da0073e9SAndroid Build Coastguard Worker        a.requires_grad_(False)
11356*da0073e9SAndroid Build Coastguard Worker
11357*da0073e9SAndroid Build Coastguard Worker        for f in [f1, f2, f3]:
11358*da0073e9SAndroid Build Coastguard Worker            if dtype.is_floating_point:
11359*da0073e9SAndroid Build Coastguard Worker                f()
11360*da0073e9SAndroid Build Coastguard Worker            else:
11361*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(
11362*da0073e9SAndroid Build Coastguard Worker                    RuntimeError,
11363*da0073e9SAndroid Build Coastguard Worker                    "floating point",
11364*da0073e9SAndroid Build Coastguard Worker                    msg=f"dt: {a.dtype} device: {a.device}",
11365*da0073e9SAndroid Build Coastguard Worker                ):
11366*da0073e9SAndroid Build Coastguard Worker                    f()
11367*da0073e9SAndroid Build Coastguard Worker
11368*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11369*da0073e9SAndroid Build Coastguard Worker    def test_advanced_indexing_backwards_large(self, device):
11370*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/22843
11371*da0073e9SAndroid Build Coastguard Worker        n = 1 << 16
11372*da0073e9SAndroid Build Coastguard Worker        x = torch.rand(n, 1, device=device, requires_grad=True)
11373*da0073e9SAndroid Build Coastguard Worker        a = x[:, [0]]
11374*da0073e9SAndroid Build Coastguard Worker        a.sum().backward()
11375*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(n, 1, device=device))
11376*da0073e9SAndroid Build Coastguard Worker
11377*da0073e9SAndroid Build Coastguard Worker    def test_advanced_indexing_backwards_memory_format(self, device):
11378*da0073e9SAndroid Build Coastguard Worker        # See https://github.com/pytorch/pytorch/issues/36956
11379*da0073e9SAndroid Build Coastguard Worker        shape = (2, 8, 1, 2)
11380*da0073e9SAndroid Build Coastguard Worker        i = torch.randint(1, shape, device=device).contiguous(
11381*da0073e9SAndroid Build Coastguard Worker            memory_format=torch.channels_last
11382*da0073e9SAndroid Build Coastguard Worker        )
11383*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(shape, requires_grad=True, device=device)
11384*da0073e9SAndroid Build Coastguard Worker        x[i].sum().backward()
11385*da0073e9SAndroid Build Coastguard Worker
11386*da0073e9SAndroid Build Coastguard Worker    def _test_reentrant_parent_error_on_cpu(self, device):
11387*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand([3, 3], requires_grad=True)
11388*da0073e9SAndroid Build Coastguard Worker        t2 = torch.rand([3, 3], device=device, requires_grad=True)
11389*da0073e9SAndroid Build Coastguard Worker        t3 = torch.rand([3, 3], device=device, requires_grad=True)
11390*da0073e9SAndroid Build Coastguard Worker
11391*da0073e9SAndroid Build Coastguard Worker        # Parent graph is a cpu graph.
11392*da0073e9SAndroid Build Coastguard Worker        t4 = t1 * t1
11393*da0073e9SAndroid Build Coastguard Worker        t5 = TestAutograd.SimulateBackwardError.apply(t4)
11394*da0073e9SAndroid Build Coastguard Worker
11395*da0073e9SAndroid Build Coastguard Worker        # Child gpu graph (much longer than parent graph).
11396*da0073e9SAndroid Build Coastguard Worker        prev = t2 * t2
11397*da0073e9SAndroid Build Coastguard Worker        for i in range(10):
11398*da0073e9SAndroid Build Coastguard Worker            prev = prev * t2
11399*da0073e9SAndroid Build Coastguard Worker        reentrant_root = prev
11400*da0073e9SAndroid Build Coastguard Worker
11401*da0073e9SAndroid Build Coastguard Worker        class ReentrantFunc(Function):
11402*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11403*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, inp):
11404*da0073e9SAndroid Build Coastguard Worker                return inp.clone()
11405*da0073e9SAndroid Build Coastguard Worker
11406*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11407*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
11408*da0073e9SAndroid Build Coastguard Worker                # Reentrant backward in child will take much longer.
11409*da0073e9SAndroid Build Coastguard Worker                reentrant_root.backward()
11410*da0073e9SAndroid Build Coastguard Worker                return grad
11411*da0073e9SAndroid Build Coastguard Worker
11412*da0073e9SAndroid Build Coastguard Worker        # Parent gpu graph.
11413*da0073e9SAndroid Build Coastguard Worker        t6 = ReentrantFunc.apply(t3)
11414*da0073e9SAndroid Build Coastguard Worker        t7 = t6 * t6
11415*da0073e9SAndroid Build Coastguard Worker
11416*da0073e9SAndroid Build Coastguard Worker        # Parent graph will error out first, while child graph will continue executing.
11417*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(Exception, "Simulate error"):
11418*da0073e9SAndroid Build Coastguard Worker            torch.autograd.backward([t5.sum(), t7.sum()])
11419*da0073e9SAndroid Build Coastguard Worker
11420*da0073e9SAndroid Build Coastguard Worker        # No grads should be accumulated since the child graph will stop execution
11421*da0073e9SAndroid Build Coastguard Worker        # after the parent receives the error.
11422*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(t2.grad)
11423*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(t1.grad)
11424*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(t3.grad)
11425*da0073e9SAndroid Build Coastguard Worker
11426*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11427*da0073e9SAndroid Build Coastguard Worker    def test_reentrant_parent_error_on_cpu(self, device):
11428*da0073e9SAndroid Build Coastguard Worker        def _get_cuda_memory_usage():
11429*da0073e9SAndroid Build Coastguard Worker            # we don't need to call CUDA synchronize because the statistics are not tracked
11430*da0073e9SAndroid Build Coastguard Worker            # at actual freeing, but when the block is marked as free.
11431*da0073e9SAndroid Build Coastguard Worker            num_devices = torch.cuda.device_count()
11432*da0073e9SAndroid Build Coastguard Worker            gc.collect()
11433*da0073e9SAndroid Build Coastguard Worker            return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
11434*da0073e9SAndroid Build Coastguard Worker
11435*da0073e9SAndroid Build Coastguard Worker        before = _get_cuda_memory_usage()
11436*da0073e9SAndroid Build Coastguard Worker
11437*da0073e9SAndroid Build Coastguard Worker        # Run as separate function so that gc can clean up everything when we
11438*da0073e9SAndroid Build Coastguard Worker        # check for memory usage.
11439*da0073e9SAndroid Build Coastguard Worker        self._test_reentrant_parent_error_on_cpu(device)
11440*da0073e9SAndroid Build Coastguard Worker
11441*da0073e9SAndroid Build Coastguard Worker        # Wait for the autograd thread to clean up failed tasks.
11442*da0073e9SAndroid Build Coastguard Worker        after = _get_cuda_memory_usage()
11443*da0073e9SAndroid Build Coastguard Worker        start = time.time()
11444*da0073e9SAndroid Build Coastguard Worker        while before != after and time.time() - start < 30:
11445*da0073e9SAndroid Build Coastguard Worker            time.sleep(0.1)
11446*da0073e9SAndroid Build Coastguard Worker            after = _get_cuda_memory_usage()
11447*da0073e9SAndroid Build Coastguard Worker
11448*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(before, after)
11449*da0073e9SAndroid Build Coastguard Worker
11450*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS
11451*da0073e9SAndroid Build Coastguard Worker    # TODO: see if these tests can be ported to OpInfos or moved to `where`'s test suite
11452*da0073e9SAndroid Build Coastguard Worker    def test_where_functional(self, device):
11453*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
11454*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
11455*da0073e9SAndroid Build Coastguard Worker        cond = mask_not_all_zeros((5, 5)).to(device=device)
11456*da0073e9SAndroid Build Coastguard Worker
11457*da0073e9SAndroid Build Coastguard Worker        def where(cond, x, y):
11458*da0073e9SAndroid Build Coastguard Worker            return torch.where(cond, x, y)
11459*da0073e9SAndroid Build Coastguard Worker
11460*da0073e9SAndroid Build Coastguard Worker        gradcheck(where, [cond, x, y], raise_exception=True)
11461*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)])
11462*da0073e9SAndroid Build Coastguard Worker
11463*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 1, 5, dtype=torch.double, device=device, requires_grad=True)
11464*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(5, 5, 1, dtype=torch.double, device=device, requires_grad=True)
11465*da0073e9SAndroid Build Coastguard Worker        gradcheck(where, [cond, x, y], raise_exception=True)
11466*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)])
11467*da0073e9SAndroid Build Coastguard Worker
11468*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS
11469*da0073e9SAndroid Build Coastguard Worker    def test_where_scalar(self, device):
11470*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True)
11471*da0073e9SAndroid Build Coastguard Worker        scalar = 4.0
11472*da0073e9SAndroid Build Coastguard Worker        cond = mask_not_all_zeros((5, 5)).to(device=device)
11473*da0073e9SAndroid Build Coastguard Worker
11474*da0073e9SAndroid Build Coastguard Worker        def where_scalar_first(cond, x):
11475*da0073e9SAndroid Build Coastguard Worker            return torch.where(cond, scalar, x)
11476*da0073e9SAndroid Build Coastguard Worker
11477*da0073e9SAndroid Build Coastguard Worker        def where_scalar_second(cond, x):
11478*da0073e9SAndroid Build Coastguard Worker            return torch.where(cond, x, scalar)
11479*da0073e9SAndroid Build Coastguard Worker
11480*da0073e9SAndroid Build Coastguard Worker        gradcheck(where_scalar_first, (cond, x))
11481*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(where_scalar_first, (cond, x))
11482*da0073e9SAndroid Build Coastguard Worker
11483*da0073e9SAndroid Build Coastguard Worker        gradcheck(where_scalar_second, (cond, x))
11484*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(where_scalar_second, (cond, x))
11485*da0073e9SAndroid Build Coastguard Worker
11486*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11487*da0073e9SAndroid Build Coastguard Worker    def test_free_unneeded_tensor(self, device):
11488*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True)
11489*da0073e9SAndroid Build Coastguard Worker        m = torch.randn(1, 3, 1, 1, device=device)
11490*da0073e9SAndroid Build Coastguard Worker
11491*da0073e9SAndroid Build Coastguard Worker        z = x.sum()
11492*da0073e9SAndroid Build Coastguard Worker        base_mem = torch.cuda.memory_allocated()
11493*da0073e9SAndroid Build Coastguard Worker        z = ((x + 2) * m).sum()
11494*da0073e9SAndroid Build Coastguard Worker        end_mem = torch.cuda.memory_allocated()
11495*da0073e9SAndroid Build Coastguard Worker
11496*da0073e9SAndroid Build Coastguard Worker        # In the end the memory usage should remain equal, because neither
11497*da0073e9SAndroid Build Coastguard Worker        # (x + 2) nor ((x + 2) * m) should be kept alive for backward, while the
11498*da0073e9SAndroid Build Coastguard Worker        # previous allocation of z had the same size as the current one.
11499*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(base_mem, end_mem)
11500*da0073e9SAndroid Build Coastguard Worker
11501*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11502*da0073e9SAndroid Build Coastguard Worker    def test_pin_memory(self, device):
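        # pin_memory() returns a copy: equal in value, a distinct tensor object, and it
        # preserves requires_grad; gradcheck verifies that the copy is differentiable.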
11503*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
11504*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x, x.pin_memory())
11505*da0073e9SAndroid Build Coastguard Worker        self.assertIsNot(x, x.pin_memory())
11506*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(x.pin_memory().requires_grad)
11507*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: x.pin_memory(), [x])
11508*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(lambda x: x.pin_memory(), [x])
11509*da0073e9SAndroid Build Coastguard Worker
11510*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11511*da0073e9SAndroid Build Coastguard Worker    def test_profiler_emit_nvtx(self, device):
11512*da0073e9SAndroid Build Coastguard Worker        # This test is not intended to ensure correctness of nvtx ranges.
11513*da0073e9SAndroid Build Coastguard Worker        # That would require something a great deal more complex (you'd have to create a
11514*da0073e9SAndroid Build Coastguard Worker        # profile in a subprocess, open it, and parse the sql somehow).
11515*da0073e9SAndroid Build Coastguard Worker        # This test is merely intended to catch if emit_nvtx breaks on construction.
11516*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
11517*da0073e9SAndroid Build Coastguard Worker        with torch.cuda.profiler.profile():
11518*da0073e9SAndroid Build Coastguard Worker            with emit_nvtx():
11519*da0073e9SAndroid Build Coastguard Worker                a.add(1.0)
11520*da0073e9SAndroid Build Coastguard Worker
11521*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11522*da0073e9SAndroid Build Coastguard Worker    def test_rnn_backward_to_input_but_not_parameters(self, device):
11523*da0073e9SAndroid Build Coastguard Worker        # this checks whether it is possible to not require gradients for the
11524*da0073e9SAndroid Build Coastguard Worker        # weight parameters, but still require them for the inputs, see #7722
11525*da0073e9SAndroid Build Coastguard Worker        l = torch.nn.LSTM(2, 3).to(device)
11526*da0073e9SAndroid Build Coastguard Worker        for p in l.parameters():
11527*da0073e9SAndroid Build Coastguard Worker            p.requires_grad = False
11528*da0073e9SAndroid Build Coastguard Worker        s = torch.randn(1, 1, 2, requires_grad=True, device=device)
11529*da0073e9SAndroid Build Coastguard Worker        out, _ = l(s)
11530*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
11531*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0)
11532*da0073e9SAndroid Build Coastguard Worker
11533*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required")
11534*da0073e9SAndroid Build Coastguard Worker    def test_profiler_emit_itt(self, device):
11535*da0073e9SAndroid Build Coastguard Worker        # This test is not intended to ensure correctness of itt ranges.
11536*da0073e9SAndroid Build Coastguard Worker        # That would require something a great deal more complex (you'd have to create a
11537*da0073e9SAndroid Build Coastguard Worker        # profile in a subprocess, open it, and parse the sql somehow).
11538*da0073e9SAndroid Build Coastguard Worker        # This test is merely intended to catch if emit_itt breaks on construction.
11539*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device)
11540*da0073e9SAndroid Build Coastguard Worker        with emit_itt():
11541*da0073e9SAndroid Build Coastguard Worker            a.add(1.0)
11542*da0073e9SAndroid Build Coastguard Worker
11543*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as randn does not support the long dtype
11544*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
11545*da0073e9SAndroid Build Coastguard Worker    def test_grad_assignment(self, devices):
11546*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, device=devices[0])
11547*da0073e9SAndroid Build Coastguard Worker
11548*da0073e9SAndroid Build Coastguard Worker        # Tests that the wrong type raises
11549*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "expected to be a Tensor or None"):
11550*da0073e9SAndroid Build Coastguard Worker            x.grad = 0
11551*da0073e9SAndroid Build Coastguard Worker
11552*da0073e9SAndroid Build Coastguard Worker        # Tests that the wrong shape raises
11553*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11554*da0073e9SAndroid Build Coastguard Worker            x.grad = torch.randn(2, 2, device=devices[0])
11555*da0073e9SAndroid Build Coastguard Worker
11556*da0073e9SAndroid Build Coastguard Worker        # Tests that the wrong dtype raises
11557*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11558*da0073e9SAndroid Build Coastguard Worker            x.grad = torch.randn(5, 5, dtype=torch.long, device=devices[0])
11559*da0073e9SAndroid Build Coastguard Worker
11560*da0073e9SAndroid Build Coastguard Worker        # Tests that self-assignment raises
11561*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11562*da0073e9SAndroid Build Coastguard Worker            x.grad = x
11563*da0073e9SAndroid Build Coastguard Worker
11564*da0073e9SAndroid Build Coastguard Worker        # Tests device -> cpu grad assignment raises
11565*da0073e9SAndroid Build Coastguard Worker        if self.device_type != "cpu":
11566*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
11567*da0073e9SAndroid Build Coastguard Worker                t_cpu = torch.rand(5, 5)
11568*da0073e9SAndroid Build Coastguard Worker                t_cpu.grad = torch.randn(5, 5, device=devices[0])
11569*da0073e9SAndroid Build Coastguard Worker
11570*da0073e9SAndroid Build Coastguard Worker        # Tests half type on CUDA
11571*da0073e9SAndroid Build Coastguard Worker        if self.device_type == "cuda":
11572*da0073e9SAndroid Build Coastguard Worker            x = x.to(dtype=torch.half, device=devices[0])
11573*da0073e9SAndroid Build Coastguard Worker            x.grad = torch.zeros_like(x)
11574*da0073e9SAndroid Build Coastguard Worker
11575*da0073e9SAndroid Build Coastguard Worker        # Tests cross-device assignment raises
11576*da0073e9SAndroid Build Coastguard Worker        if len(devices) > 1:
11577*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(5, 5, device=devices[0])
11578*da0073e9SAndroid Build Coastguard Worker            with self.assertRaises(RuntimeError):
11579*da0073e9SAndroid Build Coastguard Worker                x.grad = torch.randn(5, 5, device=devices[1])
11580*da0073e9SAndroid Build Coastguard Worker
11581*da0073e9SAndroid Build Coastguard Worker    @dtypesIfMPS(torch.float32)
11582*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(1)
11583*da0073e9SAndroid Build Coastguard Worker    @dtypes(torch.float, torch.double)
11584*da0073e9SAndroid Build Coastguard Worker    def test_requires_grad_factory(self, devices, dtype):
11585*da0073e9SAndroid Build Coastguard Worker        fns = [torch.ones_like, torch.randn_like]
11586*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 3, dtype=dtype, device=devices[0])
11587*da0073e9SAndroid Build Coastguard Worker
11588*da0073e9SAndroid Build Coastguard Worker        for fn in fns:
11589*da0073e9SAndroid Build Coastguard Worker            for requires_grad in [True, False]:
11590*da0073e9SAndroid Build Coastguard Worker                output = fn(
11591*da0073e9SAndroid Build Coastguard Worker                    x, dtype=dtype, device=devices[0], requires_grad=requires_grad
11592*da0073e9SAndroid Build Coastguard Worker                )
11593*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(requires_grad, output.requires_grad)
11594*da0073e9SAndroid Build Coastguard Worker                self.assertIs(dtype, output.dtype)
11595*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(devices[0], str(output.device))
11596*da0073e9SAndroid Build Coastguard Worker
11597*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
11598*da0073e9SAndroid Build Coastguard Worker    def test_unused_output_device(self, devices):
11599*da0073e9SAndroid Build Coastguard Worker        from torch.nn.parallel._functions import Broadcast
11600*da0073e9SAndroid Build Coastguard Worker
11601*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(5, 5, dtype=torch.float, device=devices[0], requires_grad=True)
11602*da0073e9SAndroid Build Coastguard Worker        outputs = Broadcast.apply(list(range(len(devices))), x)
11603*da0073e9SAndroid Build Coastguard Worker        y = outputs[-1] * 2
11604*da0073e9SAndroid Build Coastguard Worker        y.sum().backward()
11605*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones(5, 5) * 2)
11606*da0073e9SAndroid Build Coastguard Worker
11607*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
11608*da0073e9SAndroid Build Coastguard Worker    def test_backward_device(self, devices):
11609*da0073e9SAndroid Build Coastguard Worker        # check that current device matches the variable's device
11610*da0073e9SAndroid Build Coastguard Worker        device = [None]
11611*da0073e9SAndroid Build Coastguard Worker
11612*da0073e9SAndroid Build Coastguard Worker        class Identity(torch.autograd.Function):
11613*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11614*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
11615*da0073e9SAndroid Build Coastguard Worker                return x.clone()
11616*da0073e9SAndroid Build Coastguard Worker
11617*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11618*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
11619*da0073e9SAndroid Build Coastguard Worker                device[0] = grad_output.device
11620*da0073e9SAndroid Build Coastguard Worker                return grad_output.clone()
11621*da0073e9SAndroid Build Coastguard Worker
11622*da0073e9SAndroid Build Coastguard Worker        v = torch.randn(1, device=devices[1], requires_grad=True)
11623*da0073e9SAndroid Build Coastguard Worker        Identity.apply(v).backward()
11624*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(str(device[0]), devices[1])
11625*da0073e9SAndroid Build Coastguard Worker
11626*da0073e9SAndroid Build Coastguard Worker    @deviceCountAtLeast(2)
11627*da0073e9SAndroid Build Coastguard Worker    def test_inputbuffer_add_multidevice(self, devices):
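        # Two gradient contributions for `input` arrive from a different device, which
        # exercises gradient accumulation (the InputBuffer) across devices.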
11628*da0073e9SAndroid Build Coastguard Worker        input = torch.randn(1, device=devices[0], requires_grad=True)
11629*da0073e9SAndroid Build Coastguard Worker        output = input.to(device=devices[1]) + input.to(device=devices[1])
11630*da0073e9SAndroid Build Coastguard Worker        output.backward()
11631*da0073e9SAndroid Build Coastguard Worker
11632*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
11633*da0073e9SAndroid Build Coastguard Worker    def test_copy_(self, device):
11634*da0073e9SAndroid Build Coastguard Worker        # At the time of writing this test, copy_ was not generated from native_functions.yaml,
11635*da0073e9SAndroid Build Coastguard Worker        # and there was a bug where bfloat16 was not recognized as a floating point type.
11636*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(10, device=device, requires_grad=True)
11637*da0073e9SAndroid Build Coastguard Worker        floating_dt = floating_types_and(torch.half, torch.bfloat16)
11638*da0073e9SAndroid Build Coastguard Worker        for dt in floating_dt:
11639*da0073e9SAndroid Build Coastguard Worker            y = torch.empty(10, device=device, dtype=dt)
11640*da0073e9SAndroid Build Coastguard Worker            y.copy_(x)
11641*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(y.requires_grad)
11642*da0073e9SAndroid Build Coastguard Worker            z = x.to(torch.bfloat16)
11643*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(z.requires_grad)
11644*da0073e9SAndroid Build Coastguard Worker
11645*da0073e9SAndroid Build Coastguard Worker    def test_copy_forward_ad_broadcasting(self, device):
11646*da0073e9SAndroid Build Coastguard Worker        # copy_ allows the src to have a different shape from self as long as src is
11647*da0073e9SAndroid Build Coastguard Worker        # broadcastable to self. Make sure forward AD handles this case.
11648*da0073e9SAndroid Build Coastguard Worker        primal = torch.rand(3, 3, device=device)
11649*da0073e9SAndroid Build Coastguard Worker        tangent = torch.rand(3, 3, device=device)
11650*da0073e9SAndroid Build Coastguard Worker        non_dual = torch.rand(1, 3, 3, device=device)
11651*da0073e9SAndroid Build Coastguard Worker
11652*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
11653*da0073e9SAndroid Build Coastguard Worker            dual = fwAD.make_dual(primal, tangent)
11654*da0073e9SAndroid Build Coastguard Worker            non_dual.copy_(dual)
11655*da0073e9SAndroid Build Coastguard Worker
11656*da0073e9SAndroid Build Coastguard Worker    def test_copy_forward_ad_same_layout_copies_grad(self, device):
11657*da0073e9SAndroid Build Coastguard Worker        primal = torch.tensor([[3.0], [4.0]], device=device)
11658*da0073e9SAndroid Build Coastguard Worker        tangent = torch.tensor([[5.0], [6.0]], device=device)
11659*da0073e9SAndroid Build Coastguard Worker
11660*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
11661*da0073e9SAndroid Build Coastguard Worker            x_dual = fwAD.make_dual(primal, tangent)
11662*da0073e9SAndroid Build Coastguard Worker            non_dual = torch.tensor([[1.0], [2.0]])
11663*da0073e9SAndroid Build Coastguard Worker            non_dual.copy_(x_dual)
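            # copy_ should copy the tangent into non_dual rather than alias the original
            # tangent tensor.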
11664*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(fwAD.unpack_dual(non_dual).tangent is not tangent)
11665*da0073e9SAndroid Build Coastguard Worker
11666*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11667*da0073e9SAndroid Build Coastguard Worker    def test_simple_reentrant_cross_device(self, device):
11668*da0073e9SAndroid Build Coastguard Worker        class ReentrantFunc(Function):
11669*da0073e9SAndroid Build Coastguard Worker            _cpu_mode = True
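            # When True, the reentrant backward builds its throwaway graph on the CPU;
            # when False, it builds it on `device`, so the nested backward() runs work
            # on that device's autograd thread.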
11670*da0073e9SAndroid Build Coastguard Worker
11671*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11672*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
11673*da0073e9SAndroid Build Coastguard Worker                return x * (x + 2)
11674*da0073e9SAndroid Build Coastguard Worker
11675*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11676*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad_output):
11677*da0073e9SAndroid Build Coastguard Worker                with torch.enable_grad():
11678*da0073e9SAndroid Build Coastguard Worker                    if ReentrantFunc._cpu_mode:
11679*da0073e9SAndroid Build Coastguard Worker                        new_param = torch.randn(2, 2, requires_grad=True)
11680*da0073e9SAndroid Build Coastguard Worker                        (new_param**2).sum().backward()
11681*da0073e9SAndroid Build Coastguard Worker                    else:
11682*da0073e9SAndroid Build Coastguard Worker                        new_param = torch.randn(2, 2, device=device, requires_grad=True)
11683*da0073e9SAndroid Build Coastguard Worker                        (new_param**2).sum().backward()
11684*da0073e9SAndroid Build Coastguard Worker                return grad_output
11685*da0073e9SAndroid Build Coastguard Worker
11686*da0073e9SAndroid Build Coastguard Worker        # Reentrant starts on GPU thread, finishes on GPU thread
11687*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, device=device, requires_grad=True)
11688*da0073e9SAndroid Build Coastguard Worker        out = ReentrantFunc.apply(x)
11689*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
11690*da0073e9SAndroid Build Coastguard Worker
11691*da0073e9SAndroid Build Coastguard Worker        # Reentrant starts on CPU thread, finishes on GPU thread
11692*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, requires_grad=True)
11693*da0073e9SAndroid Build Coastguard Worker        # set ReentrantFunc node to GPU to emit tasks to GPU queue
11694*da0073e9SAndroid Build Coastguard Worker        ReentrantFunc._cpu_mode = False
11695*da0073e9SAndroid Build Coastguard Worker        out = ReentrantFunc.apply(x)
11696*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
11697*da0073e9SAndroid Build Coastguard Worker
11698*da0073e9SAndroid Build Coastguard Worker        # Reentrant starts on GPU thread, finishes on CPU thread
11699*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(2, 2, device=device, requires_grad=True)
11700*da0073e9SAndroid Build Coastguard Worker        # set ReentrantFunc node to CPU to emit tasks to CPU queue
11701*da0073e9SAndroid Build Coastguard Worker        ReentrantFunc._cpu_mode = True
11702*da0073e9SAndroid Build Coastguard Worker        out = ReentrantFunc.apply(x)
11703*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
11704*da0073e9SAndroid Build Coastguard Worker
11705*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11706*da0073e9SAndroid Build Coastguard Worker    def test_cross_device_reentrant_autograd(self, device):
11707*da0073e9SAndroid Build Coastguard Worker        # Output on gpu so that this task will be associated with the gpu thread
11708*da0073e9SAndroid Build Coastguard Worker        def fn_on_gpu(inp):
11709*da0073e9SAndroid Build Coastguard Worker            # Artificially increase the priority of the next op to make sure it runs
11710*da0073e9SAndroid Build Coastguard Worker            # as soon as we reach it before the ops of branch1.
11711*da0073e9SAndroid Build Coastguard Worker            dummy = inp * 2 * 2 * 2 * 2
11712*da0073e9SAndroid Build Coastguard Worker            return inp.to(device=device)
11713*da0073e9SAndroid Build Coastguard Worker
11714*da0073e9SAndroid Build Coastguard Worker        def parent_on_cpu(inp):
11715*da0073e9SAndroid Build Coastguard Worker            # Slow branch of ops on gpu so that the work queue for the gpu thread
11716*da0073e9SAndroid Build Coastguard Worker            # won't empty too quickly. They also have lower priorities than the
11717*da0073e9SAndroid Build Coastguard Worker            # ones created by fn_on_gpu
11718*da0073e9SAndroid Build Coastguard Worker            branch1 = inp.to(device=device)
11719*da0073e9SAndroid Build Coastguard Worker            branch1 = branch1 / branch1
11720*da0073e9SAndroid Build Coastguard Worker            branch1 = branch1 / branch1
11721*da0073e9SAndroid Build Coastguard Worker            branch1 = branch1 / branch1
11722*da0073e9SAndroid Build Coastguard Worker            # Perform checkpoint on cpu tensors, so the last op performed in the reentrant
11723*da0073e9SAndroid Build Coastguard Worker            # autograd is an AccumulateGrad that runs on the cpu thread on behalf of the gpu
11724*da0073e9SAndroid Build Coastguard Worker            # thread. The cpu thread will then notify the gpu thread with an empty NodeTask.
11725*da0073e9SAndroid Build Coastguard Worker            branch2 = checkpoint(fn_on_gpu, inp, use_reentrant=True)
11726*da0073e9SAndroid Build Coastguard Worker            out = branch2 + branch1
11727*da0073e9SAndroid Build Coastguard Worker            return out
11728*da0073e9SAndroid Build Coastguard Worker
11729*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2, requires_grad=True)
11730*da0073e9SAndroid Build Coastguard Worker        out = parent_on_cpu(inp)
11731*da0073e9SAndroid Build Coastguard Worker        # This will segfault if the empty NodeTask is not handled properly in the
11732*da0073e9SAndroid Build Coastguard Worker        # gpu thread ReadyQueue
11733*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
11734*da0073e9SAndroid Build Coastguard Worker
11735*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_backprop_base(self, device):
11736*da0073e9SAndroid Build Coastguard Worker        # modify view and back-prop through base
11737*da0073e9SAndroid Build Coastguard Worker        root = torch.randn(2, 2, device=device, requires_grad=True)
11738*da0073e9SAndroid Build Coastguard Worker        x = root.clone()
11739*da0073e9SAndroid Build Coastguard Worker        v1 = x.narrow(0, 0, 1)
11740*da0073e9SAndroid Build Coastguard Worker        v1.mul_(2)
11741*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
11742*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]])
11743*da0073e9SAndroid Build Coastguard Worker
11744*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_backprop_view_of_view(self, device):
11745*da0073e9SAndroid Build Coastguard Worker        # modify view and backprop through view-of-view
11746*da0073e9SAndroid Build Coastguard Worker        root = torch.randn(2, 2, device=device, requires_grad=True)
11747*da0073e9SAndroid Build Coastguard Worker        x = root.clone()
11748*da0073e9SAndroid Build Coastguard Worker        v1 = x.narrow(0, 0, 1)
11749*da0073e9SAndroid Build Coastguard Worker        v2 = x.narrow(0, 0, 1)
11750*da0073e9SAndroid Build Coastguard Worker        v1.mul_(2)
11751*da0073e9SAndroid Build Coastguard Worker        v2.sum().backward()
11752*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]])
11753*da0073e9SAndroid Build Coastguard Worker
11754*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_of_view(self, device):
11755*da0073e9SAndroid Build Coastguard Worker        # modify view-of-view and backprop through base
11756*da0073e9SAndroid Build Coastguard Worker        root = torch.randn(2, 2, device=device, requires_grad=True)
11757*da0073e9SAndroid Build Coastguard Worker        x = root.clone()
11758*da0073e9SAndroid Build Coastguard Worker
11759*da0073e9SAndroid Build Coastguard Worker        v1 = x.narrow(0, 0, 1)
11760*da0073e9SAndroid Build Coastguard Worker        v2 = v1.narrow(1, 1, 1)
11761*da0073e9SAndroid Build Coastguard Worker        v2.mul_(2)
11762*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
11763*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]])
11764*da0073e9SAndroid Build Coastguard Worker
11765*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11766*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_then_no_grad(self, device):
11767*da0073e9SAndroid Build Coastguard Worker        # Perform an in-place operation on a view of a non-leaf variable.
11768*da0073e9SAndroid Build Coastguard Worker        a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True)
11769*da0073e9SAndroid Build Coastguard Worker        b = a * 2
11770*da0073e9SAndroid Build Coastguard Worker        c = b.view_as(b)
11771*da0073e9SAndroid Build Coastguard Worker        c[0][0] = 3
11772*da0073e9SAndroid Build Coastguard Worker
11773*da0073e9SAndroid Build Coastguard Worker        # Force a graph update with grad disabled.
11774*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
11775*da0073e9SAndroid Build Coastguard Worker            c.grad_fn
11776*da0073e9SAndroid Build Coastguard Worker
11777*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
11778*da0073e9SAndroid Build Coastguard Worker
11779*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11780*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_gradcheck(self, device):
11781*da0073e9SAndroid Build Coastguard Worker        # gradcheck modifications to views
11782*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
11783*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
11784*da0073e9SAndroid Build Coastguard Worker
11785*da0073e9SAndroid Build Coastguard Worker        def func(root, b):
11786*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
11787*da0073e9SAndroid Build Coastguard Worker            x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b)
11788*da0073e9SAndroid Build Coastguard Worker            x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b)
11789*da0073e9SAndroid Build Coastguard Worker            return x
11790*da0073e9SAndroid Build Coastguard Worker
11791*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [a, b], raise_exception=True)
11792*da0073e9SAndroid Build Coastguard Worker        go = torch.randn(
11793*da0073e9SAndroid Build Coastguard Worker            a.size(), dtype=torch.double, device=device, requires_grad=True
11794*da0073e9SAndroid Build Coastguard Worker        )
11795*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, (a, b), (go,))
11796*da0073e9SAndroid Build Coastguard Worker
11797*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_multiple_outputs(self, device):
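        # unbind() returns several views of the same base; modifying one of them
        # in-place is not supported by autograd and should raise.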
11798*da0073e9SAndroid Build Coastguard Worker        root = torch.arange(9.0, dtype=torch.double).reshape(3, 3).requires_grad_()
11799*da0073e9SAndroid Build Coastguard Worker        x = root.clone()
11800*da0073e9SAndroid Build Coastguard Worker        v1 = x.unbind()
11801*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11802*da0073e9SAndroid Build Coastguard Worker            v1[0].mul_(2)
11803*da0073e9SAndroid Build Coastguard Worker
11804*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11805*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_of_multiple_output_view(self, device):
11806*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(
11807*da0073e9SAndroid Build Coastguard Worker            10, dtype=torch.double, device=device, requires_grad=True
11808*da0073e9SAndroid Build Coastguard Worker        ).clone()
11809*da0073e9SAndroid Build Coastguard Worker        b = a.unbind(0)
11810*da0073e9SAndroid Build Coastguard Worker        c = b[0].view_as(b[0])
11811*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11812*da0073e9SAndroid Build Coastguard Worker            c.mul_(2)
11813*da0073e9SAndroid Build Coastguard Worker
11814*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # MPS backend doesn't support double types
11815*da0073e9SAndroid Build Coastguard Worker    def test_inplace_multiple_output_view_of_view(self, device):
11816*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(
11817*da0073e9SAndroid Build Coastguard Worker            10, dtype=torch.double, device=device, requires_grad=True
11818*da0073e9SAndroid Build Coastguard Worker        ).clone()
11819*da0073e9SAndroid Build Coastguard Worker        b = a.view_as(a)
11820*da0073e9SAndroid Build Coastguard Worker        c = b.unbind(0)
11821*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(RuntimeError):
11822*da0073e9SAndroid Build Coastguard Worker            c[0].mul_(2)
11823*da0073e9SAndroid Build Coastguard Worker
11824*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # MPS backend doesn't support double types
11825*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_makes_base_require_grad(self, device):
11826*da0073e9SAndroid Build Coastguard Worker        # in-place modification to view makes base require grad
11827*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False)
11828*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True)
11829*da0073e9SAndroid Build Coastguard Worker
11830*da0073e9SAndroid Build Coastguard Worker        def func(root, b):
11831*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
11832*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(x.requires_grad)
11833*da0073e9SAndroid Build Coastguard Worker            x.narrow(1, 2, 2).mul_(b)
11834*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(x.requires_grad)
11835*da0073e9SAndroid Build Coastguard Worker            return x
11836*da0073e9SAndroid Build Coastguard Worker
11837*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [a, b], raise_exception=True)
11838*da0073e9SAndroid Build Coastguard Worker        go = torch.randn(
11839*da0073e9SAndroid Build Coastguard Worker            a.size(), dtype=torch.double, device=device, requires_grad=True
11840*da0073e9SAndroid Build Coastguard Worker        )
11841*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, (a, b), (go,))
11842*da0073e9SAndroid Build Coastguard Worker
11843*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_backprop_view(self, device):
11844*da0073e9SAndroid Build Coastguard Worker        # modify view and backprop through view
11845*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([2.0, 5.0], device=device, requires_grad=False)
11846*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor([3.0], device=device, requires_grad=True)
11847*da0073e9SAndroid Build Coastguard Worker        res = a.narrow(0, 1, 1).mul_(b)
11848*da0073e9SAndroid Build Coastguard Worker        res.sum().backward()
11849*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad.tolist(), [5])
11850*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(a.grad)
11851*da0073e9SAndroid Build Coastguard Worker
11852*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11853*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_modify_base(self, device):
11854*da0073e9SAndroid Build Coastguard Worker        # Test that an in-place operation that forces the base to require grad
11855*da0073e9SAndroid Build Coastguard Worker        # also forces any previously-created views to require grad and to
11856*da0073e9SAndroid Build Coastguard Worker        # backprop correctly
11857*da0073e9SAndroid Build Coastguard Worker        r = torch.ones(1, dtype=torch.double, device=device, requires_grad=True)
11858*da0073e9SAndroid Build Coastguard Worker
11859*da0073e9SAndroid Build Coastguard Worker        def fn(r):
11860*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(5, dtype=torch.double, device=device)
11861*da0073e9SAndroid Build Coastguard Worker            v = x.select(0, 1)
11862*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(v.requires_grad)
11863*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(v.grad_fn)
11864*da0073e9SAndroid Build Coastguard Worker            x.add_(r)  # v is now dependent on r due to the in-place op on x
11865*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(v.requires_grad)
11866*da0073e9SAndroid Build Coastguard Worker            return v
11867*da0073e9SAndroid Build Coastguard Worker
11868*da0073e9SAndroid Build Coastguard Worker        gradcheck(fn, [r])
11869*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(fn, [r])
11870*da0073e9SAndroid Build Coastguard Worker
11871*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11872*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_python(self, device):
11873*da0073e9SAndroid Build Coastguard Worker        # in-place modifications of a view created by a Python autograd Function
11874*da0073e9SAndroid Build Coastguard Worker        a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True)
11875*da0073e9SAndroid Build Coastguard Worker        b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True)
11876*da0073e9SAndroid Build Coastguard Worker
11877*da0073e9SAndroid Build Coastguard Worker        class PyAdd(torch.autograd.Function):
11878*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11879*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, y):
11880*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(x)
11881*da0073e9SAndroid Build Coastguard Worker                x.add_(y)
11882*da0073e9SAndroid Build Coastguard Worker                return x
11883*da0073e9SAndroid Build Coastguard Worker
11884*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11885*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
11886*da0073e9SAndroid Build Coastguard Worker                return grad, grad
11887*da0073e9SAndroid Build Coastguard Worker
11888*da0073e9SAndroid Build Coastguard Worker        def func(root, b):
11889*da0073e9SAndroid Build Coastguard Worker            x = root.clone()
11890*da0073e9SAndroid Build Coastguard Worker            PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b)
11891*da0073e9SAndroid Build Coastguard Worker            PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b)
11892*da0073e9SAndroid Build Coastguard Worker            return x
11893*da0073e9SAndroid Build Coastguard Worker
11894*da0073e9SAndroid Build Coastguard Worker        gradcheck(func, [a, b], raise_exception=True)
11895*da0073e9SAndroid Build Coastguard Worker        go = torch.randn(
11896*da0073e9SAndroid Build Coastguard Worker            a.size(), dtype=torch.double, device=device, requires_grad=True
11897*da0073e9SAndroid Build Coastguard Worker        )
11898*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(func, (a, b), (go,))
11899*da0073e9SAndroid Build Coastguard Worker
11900*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_non_contig(self, device):
11901*da0073e9SAndroid Build Coastguard Worker        root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True)
11902*da0073e9SAndroid Build Coastguard Worker        x = root.clone()
11903*da0073e9SAndroid Build Coastguard Worker        v1 = x.narrow(0, 0, 1)
11904*da0073e9SAndroid Build Coastguard Worker        v2 = v1.narrow(1, 1, 1)
11905*da0073e9SAndroid Build Coastguard Worker        v2.mul_(2)
11906*da0073e9SAndroid Build Coastguard Worker        x.sum().backward()
11907*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]])
11908*da0073e9SAndroid Build Coastguard Worker
11909*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_multi_output_unsafe(self, device):
11910*da0073e9SAndroid Build Coastguard Worker        for f in [
11911*da0073e9SAndroid Build Coastguard Worker            lambda t: t.unsafe_split(1),
11912*da0073e9SAndroid Build Coastguard Worker            lambda t: t.unsafe_split_with_sizes((1, 1, 1)),
11913*da0073e9SAndroid Build Coastguard Worker            lambda t: t.unsafe_chunk(3),
11914*da0073e9SAndroid Build Coastguard Worker        ]:
11915*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(3, 3, device=device, requires_grad=True)
11916*da0073e9SAndroid Build Coastguard Worker            b = a + a
11917*da0073e9SAndroid Build Coastguard Worker            s1, s2, s3 = f(b)
11918*da0073e9SAndroid Build Coastguard Worker            s1.mul_(s2)
11919*da0073e9SAndroid Build Coastguard Worker            s1.sum().backward()
11920*da0073e9SAndroid Build Coastguard Worker
11921*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_multi_output_safe(self, device):
11922*da0073e9SAndroid Build Coastguard Worker        for f in [
11923*da0073e9SAndroid Build Coastguard Worker            lambda t: t.split(1),
11924*da0073e9SAndroid Build Coastguard Worker            lambda t: t.split_with_sizes((1, 1, 1)),
11925*da0073e9SAndroid Build Coastguard Worker            lambda t: t.chunk(3),
11926*da0073e9SAndroid Build Coastguard Worker        ]:
11927*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(3, 3, device=device, requires_grad=True)
11928*da0073e9SAndroid Build Coastguard Worker            b = a + a
11929*da0073e9SAndroid Build Coastguard Worker            s1, s2, s3 = f(b)
11930*da0073e9SAndroid Build Coastguard Worker            error_msg = (
11931*da0073e9SAndroid Build Coastguard Worker                "This view is the output of a function that returns multiple views."
11932*da0073e9SAndroid Build Coastguard Worker            )
11933*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, error_msg):
11934*da0073e9SAndroid Build Coastguard Worker                s1.mul_(s2)
11935*da0073e9SAndroid Build Coastguard Worker
11936*da0073e9SAndroid Build Coastguard Worker    def test_inplace_on_view_undefined_grad_output(self, device):
11937*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor([1.0], requires_grad=True)
11938*da0073e9SAndroid Build Coastguard Worker        c = a.clone()
11939*da0073e9SAndroid Build Coastguard Worker        v = c[:]
11940*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(1.0, requires_grad=True)
11941*da0073e9SAndroid Build Coastguard Worker
11942*da0073e9SAndroid Build Coastguard Worker        class InplaceFunc(torch.autograd.Function):
11943*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11944*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, other):
11945*da0073e9SAndroid Build Coastguard Worker                ctx.mark_dirty(x)
11946*da0073e9SAndroid Build Coastguard Worker                return x.mul_(2)
11947*da0073e9SAndroid Build Coastguard Worker
11948*da0073e9SAndroid Build Coastguard Worker            @staticmethod
11949*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad):
11950*da0073e9SAndroid Build Coastguard Worker                return grad * 2, None
11951*da0073e9SAndroid Build Coastguard Worker
11952*da0073e9SAndroid Build Coastguard Worker        out = InplaceFunc.apply(v, b)
11953*da0073e9SAndroid Build Coastguard Worker        out.backward()
11954*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(b.grad)
11955*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad.item(), 2)
11956*da0073e9SAndroid Build Coastguard Worker
11957*da0073e9SAndroid Build Coastguard Worker    @skipIfMps  # the test doesn't work on MPS as double types are not supported
11958*da0073e9SAndroid Build Coastguard Worker    def test_mv_grad_stride_0(self, device):
11959*da0073e9SAndroid Build Coastguard Worker        # Reference: https://github.com/pytorch/pytorch/issues/38315
11960*da0073e9SAndroid Build Coastguard Worker        mat = torch.randn(2, 2, dtype=torch.double, device=device)
11961*da0073e9SAndroid Build Coastguard Worker        vec = torch.randn(1, dtype=torch.double, device=device).requires_grad_(True)
11962*da0073e9SAndroid Build Coastguard Worker
11963*da0073e9SAndroid Build Coastguard Worker        def fn(vec):
11964*da0073e9SAndroid Build Coastguard Worker            # Expand inside the function to make sure the input to
11965*da0073e9SAndroid Build Coastguard Worker            # gradcheck does not have overlapping memory
11966*da0073e9SAndroid Build Coastguard Worker            vec = vec.expand(2)
11967*da0073e9SAndroid Build Coastguard Worker            return (mat @ vec).sum()
11968*da0073e9SAndroid Build Coastguard Worker
11969*da0073e9SAndroid Build Coastguard Worker        gradcheck(fn, (vec,))
11970*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(fn, (vec,))
11971*da0073e9SAndroid Build Coastguard Worker
11972*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
11973*da0073e9SAndroid Build Coastguard Worker    def test_gradcheck_input_output_different_device(self, device):
11974*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((1,), dtype=torch.double, device="cuda", requires_grad=True)
11975*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: x.to("cpu"), (x,))
11976*da0073e9SAndroid Build Coastguard Worker
11977*da0073e9SAndroid Build Coastguard Worker        x = torch.ones((1,), dtype=torch.double, device="cpu", requires_grad=True)
11978*da0073e9SAndroid Build Coastguard Worker        gradcheck(lambda x: x.to("cuda"), (x,))
11979*da0073e9SAndroid Build Coastguard Worker
11980*da0073e9SAndroid Build Coastguard Worker    def test_strided_leaf_grad_layout(self, device):
11981*da0073e9SAndroid Build Coastguard Worker        # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf.
11982*da0073e9SAndroid Build Coastguard Worker        for fmt_a in (torch.contiguous_format, torch.channels_last):
11983*da0073e9SAndroid Build Coastguard Worker            for fmt_b in (torch.contiguous_format, torch.channels_last):
11984*da0073e9SAndroid Build Coastguard Worker                a = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_a)
11985*da0073e9SAndroid Build Coastguard Worker                b = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_b)
11986*da0073e9SAndroid Build Coastguard Worker                a.requires_grad_()
11987*da0073e9SAndroid Build Coastguard Worker                b.requires_grad_()
11988*da0073e9SAndroid Build Coastguard Worker                # checks (1) for broadcasted gradients
11989*da0073e9SAndroid Build Coastguard Worker                a.sum().backward()
11990*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.grad.stride(), a.stride())
11991*da0073e9SAndroid Build Coastguard Worker                b.sum().backward()
11992*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.grad.stride(), b.stride())
11993*da0073e9SAndroid Build Coastguard Worker                # checks (1) for non-broadcasted gradients
11994*da0073e9SAndroid Build Coastguard Worker                a.grad = None
11995*da0073e9SAndroid Build Coastguard Worker                b.grad = None
11996*da0073e9SAndroid Build Coastguard Worker                (a * b).sum().backward()
11997*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.grad.stride(), a.stride())
11998*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(b.grad.stride(), b.stride())
11999*da0073e9SAndroid Build Coastguard Worker
12000*da0073e9SAndroid Build Coastguard Worker        # (2) If leaf isn't dense, checks that grads are row-major contiguous.
12001*da0073e9SAndroid Build Coastguard Worker        c = torch.empty_strided((2, 2), (4, 2), device=device).copy_(
12002*da0073e9SAndroid Build Coastguard Worker            torch.rand((2, 2), device=device)
12003*da0073e9SAndroid Build Coastguard Worker        )
12004*da0073e9SAndroid Build Coastguard Worker        c.requires_grad_()
12005*da0073e9SAndroid Build Coastguard Worker        d = torch.rand((2, 2), device=device)
12006*da0073e9SAndroid Build Coastguard Worker        # checks (2) for broadcasted gradients
12007*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
12008*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c.grad.stride(), (2, 1))
12009*da0073e9SAndroid Build Coastguard Worker        # checks (2) for non-broadcasted gradients
12010*da0073e9SAndroid Build Coastguard Worker        c.grad = None
12011*da0073e9SAndroid Build Coastguard Worker        (c * d).sum().backward()
12012*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c.grad.stride(), (2, 1))
12013*da0073e9SAndroid Build Coastguard Worker
12014*da0073e9SAndroid Build Coastguard Worker    @skipIfMps
12015*da0073e9SAndroid Build Coastguard Worker    def test_copy_r_to_c(self, device):
12016*da0073e9SAndroid Build Coastguard Worker        out_c = torch.empty(3, 2, dtype=torch.cdouble, device=device)
12017*da0073e9SAndroid Build Coastguard Worker        inp_r = torch.randn(3, 2, dtype=torch.double, device=device, requires_grad=True)
12018*da0073e9SAndroid Build Coastguard Worker
12019*da0073e9SAndroid Build Coastguard Worker        def do_test():
12020*da0073e9SAndroid Build Coastguard Worker            out_c.copy_(inp_r)
12021*da0073e9SAndroid Build Coastguard Worker            out_c_inter = out_c.sum()
12022*da0073e9SAndroid Build Coastguard Worker            out_c_inter.abs().backward()
12023*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
12024*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
12025*da0073e9SAndroid Build Coastguard Worker                    inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real
12026*da0073e9SAndroid Build Coastguard Worker                )
12027*da0073e9SAndroid Build Coastguard Worker
12028*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(do_test)
12029*da0073e9SAndroid Build Coastguard Worker
12030*da0073e9SAndroid Build Coastguard Worker    def test_to_r_to_c(self, device):
12031*da0073e9SAndroid Build Coastguard Worker        def do_test():
12032*da0073e9SAndroid Build Coastguard Worker            inp_r = torch.randn(
12033*da0073e9SAndroid Build Coastguard Worker                3, 2, dtype=torch.double, device=device, requires_grad=True
12034*da0073e9SAndroid Build Coastguard Worker            )
12035*da0073e9SAndroid Build Coastguard Worker            out = inp_r.to(torch.complex128)
12036*da0073e9SAndroid Build Coastguard Worker            out_inter = out.sum()
12037*da0073e9SAndroid Build Coastguard Worker            out_inter.abs().backward()
12038*da0073e9SAndroid Build Coastguard Worker            with torch.no_grad():
12039*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(
12040*da0073e9SAndroid Build Coastguard Worker                    inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real
12041*da0073e9SAndroid Build Coastguard Worker                )
12042*da0073e9SAndroid Build Coastguard Worker
12043*da0073e9SAndroid Build Coastguard Worker        self.assertNotWarn(do_test)
12044*da0073e9SAndroid Build Coastguard Worker
12045*da0073e9SAndroid Build Coastguard Worker    def test_non_differentiable_ops(self, device):
12046*da0073e9SAndroid Build Coastguard Worker        # Just make sure the op doesn't raise an error
12047*da0073e9SAndroid Build Coastguard Worker        # and the resulting tensor has requires_grad=False.
12048*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([[1, 2], [3, 4.0]], requires_grad=True, device=device)
12049*da0073e9SAndroid Build Coastguard Worker        out = torch.isin(x, torch.tensor([2, 3], device=device))
12050*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
12051*da0073e9SAndroid Build Coastguard Worker
12052*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, 3, device=device, requires_grad=True)
12053*da0073e9SAndroid Build Coastguard Worker        out = torch.signbit(x)
12054*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(out.requires_grad)
12055*da0073e9SAndroid Build Coastguard Worker
12056*da0073e9SAndroid Build Coastguard Worker    def test_warning_in_backward(self, device):
12057*da0073e9SAndroid Build Coastguard Worker        # Test that warnings raised during backward are always propagated as Python warnings (gh-50209)
12058*da0073e9SAndroid Build Coastguard Worker        # NOTE: For device=cuda, the warning gets propagated from a worker thread
12059*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros((), device=device, requires_grad=True)
12060*da0073e9SAndroid Build Coastguard Worker        b = torch._C._nn._test_warn_in_autograd(a)
12061*da0073e9SAndroid Build Coastguard Worker
12062*da0073e9SAndroid Build Coastguard Worker        with self.assertWarnsRegex(UserWarning, "Warn from backward"):
12063*da0073e9SAndroid Build Coastguard Worker            b.backward()
12064*da0073e9SAndroid Build Coastguard Worker
12065*da0073e9SAndroid Build Coastguard Worker    def test_complex_scalar_backward(self, device):
12066*da0073e9SAndroid Build Coastguard Worker        a = torch.zeros(1, device=device, requires_grad=True)
12067*da0073e9SAndroid Build Coastguard Worker        b = a * 0.5j
12068*da0073e9SAndroid Build Coastguard Worker
12069*da0073e9SAndroid Build Coastguard Worker        msg = "grad can be implicitly created only for real scalar outputs"
12070*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
12071*da0073e9SAndroid Build Coastguard Worker            b.backward()
12072*da0073e9SAndroid Build Coastguard Worker
12073*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
12074*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(b, a)
12075*da0073e9SAndroid Build Coastguard Worker
12076*da0073e9SAndroid Build Coastguard Worker    def test_pow_real_negative_base_complex_exponent(self, device):
12077*da0073e9SAndroid Build Coastguard Worker        # OpInfo doesn't naturally support input of mixed types, hence this test here.
12078*da0073e9SAndroid Build Coastguard Worker        base = -torch.ones(2, device=device, dtype=torch.double)
12079*da0073e9SAndroid Build Coastguard Worker        exponent = torch.randn(
12080*da0073e9SAndroid Build Coastguard Worker            2, device=device, dtype=torch.cdouble, requires_grad=True
12081*da0073e9SAndroid Build Coastguard Worker        )
12082*da0073e9SAndroid Build Coastguard Worker
12083*da0073e9SAndroid Build Coastguard Worker        def fn(exponent):
12084*da0073e9SAndroid Build Coastguard Worker            return torch.pow(base, exponent)
12085*da0073e9SAndroid Build Coastguard Worker
12086*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(fn, (exponent,))
12087*da0073e9SAndroid Build Coastguard Worker
12088*da0073e9SAndroid Build Coastguard Worker        def fn(exponent):
12089*da0073e9SAndroid Build Coastguard Worker            return torch.pow(-1, exponent)
12090*da0073e9SAndroid Build Coastguard Worker
12091*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(fn, (exponent,))
12092*da0073e9SAndroid Build Coastguard Worker
12093*da0073e9SAndroid Build Coastguard Worker    def test_resize_version_bump(self, device):
12094*da0073e9SAndroid Build Coastguard Worker        x = torch.rand((1,), device=device)
12095*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((3,), device=device)
12096*da0073e9SAndroid Build Coastguard Worker        x.resize_((1, 2))
12097*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x._version, 1)
12098*da0073e9SAndroid Build Coastguard Worker        x.resize_as_(y)
12099*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x._version, 2)
12100*da0073e9SAndroid Build Coastguard Worker
12101*da0073e9SAndroid Build Coastguard Worker        # In the following cases, `resize` is a no-op,
12102*da0073e9SAndroid Build Coastguard Worker        # so the version is not bumped.
12103*da0073e9SAndroid Build Coastguard Worker        x.resize_((3,))
12104*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x._version, 2)
12105*da0073e9SAndroid Build Coastguard Worker
12106*da0073e9SAndroid Build Coastguard Worker        x.resize_as_(y)
12107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x._version, 2)
12108*da0073e9SAndroid Build Coastguard Worker
12109*da0073e9SAndroid Build Coastguard Worker
12110*da0073e9SAndroid Build Coastguard Workerclass TestAllowMutationOnSaved(TestCase):
12111*da0073e9SAndroid Build Coastguard Worker    def assertClonedLenEqual(self, ctx, n):
12112*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(ctx.cloned.items())), n)
12113*da0073e9SAndroid Build Coastguard Worker
12114*da0073e9SAndroid Build Coastguard Worker    def assertTIDMapLenEqual(self, ctx, n):
12115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n)
12116*da0073e9SAndroid Build Coastguard Worker
12117*da0073e9SAndroid Build Coastguard Worker    def test_basic(self):
12118*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(2, 3, requires_grad=True)
12119*da0073e9SAndroid Build Coastguard Worker
12120*da0073e9SAndroid Build Coastguard Worker        def fn(a):
12121*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
12122*da0073e9SAndroid Build Coastguard Worker            out = (b**2).sum()
12123*da0073e9SAndroid Build Coastguard Worker            b.sin_()
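            # b was saved for backward by the pow above; outside the context
            # manager this in-place sin_ makes the backward below raise, while
            # inside allow_mutation_on_saved_tensors the saved value is cloned
            # first, so backward succeeds.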
12124*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
12125*da0073e9SAndroid Build Coastguard Worker            return a.grad
12126*da0073e9SAndroid Build Coastguard Worker
12127*da0073e9SAndroid Build Coastguard Worker        msg = (
12128*da0073e9SAndroid Build Coastguard Worker            "variables needed for gradient computation has been modified by an inplace"
12129*da0073e9SAndroid Build Coastguard Worker        )
12130*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
12131*da0073e9SAndroid Build Coastguard Worker            fn(a)
12132*da0073e9SAndroid Build Coastguard Worker
12133*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12134*da0073e9SAndroid Build Coastguard Worker            da = fn(a)
12135*da0073e9SAndroid Build Coastguard Worker
12136*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(a * 2, da))
12137*da0073e9SAndroid Build Coastguard Worker        self.assertClonedLenEqual(ctx, 0)
12138*da0073e9SAndroid Build Coastguard Worker
12139*da0073e9SAndroid Build Coastguard Worker    def test_views(self):
12140*da0073e9SAndroid Build Coastguard Worker        a = torch.rand(2, 3, requires_grad=True)
12141*da0073e9SAndroid Build Coastguard Worker
12142*da0073e9SAndroid Build Coastguard Worker        def fn(a):
12143*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
12144*da0073e9SAndroid Build Coastguard Worker            c = b.view_as(b)
12145*da0073e9SAndroid Build Coastguard Worker            out = (b**2).sum()  # b is saved for backward; the view c aliases it
12146*da0073e9SAndroid Build Coastguard Worker            c.sin_()
12147*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
12148*da0073e9SAndroid Build Coastguard Worker            return a.grad
12149*da0073e9SAndroid Build Coastguard Worker
12150*da0073e9SAndroid Build Coastguard Worker        msg = (
12151*da0073e9SAndroid Build Coastguard Worker            "variables needed for gradient computation has been modified by an inplace"
12152*da0073e9SAndroid Build Coastguard Worker        )
12153*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, msg):
12154*da0073e9SAndroid Build Coastguard Worker            fn(a)
12155*da0073e9SAndroid Build Coastguard Worker
12156*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12157*da0073e9SAndroid Build Coastguard Worker            da = fn(a)
12158*da0073e9SAndroid Build Coastguard Worker
12159*da0073e9SAndroid Build Coastguard Worker        self.assertClonedLenEqual(ctx, 0)
12160*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(a * 2, da))
12161*da0073e9SAndroid Build Coastguard Worker
12162*da0073e9SAndroid Build Coastguard Worker    def test_save_base_and_modify_view(self):
12163*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12164*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(2, 3, requires_grad=True)
12165*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
12166*da0073e9SAndroid Build Coastguard Worker            c = b[:1]
12167*da0073e9SAndroid Build Coastguard Worker            out = b**2
12168*da0073e9SAndroid Build Coastguard Worker            # modify the view
12169*da0073e9SAndroid Build Coastguard Worker            c *= 10
12170*da0073e9SAndroid Build Coastguard Worker            # self.assertClonedLenEqual(ctx, 1)
12171*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
12172*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12173*da0073e9SAndroid Build Coastguard Worker
12174*da0073e9SAndroid Build Coastguard Worker        self.assertClonedLenEqual(ctx, 0)
12175*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(a * 2, a.grad))
12176*da0073e9SAndroid Build Coastguard Worker
12177*da0073e9SAndroid Build Coastguard Worker    def test_save_view_modify_base(self):
12178*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12179*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(2, 3, requires_grad=True)
12180*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
12181*da0073e9SAndroid Build Coastguard Worker            c = b[:]
12182*da0073e9SAndroid Build Coastguard Worker            out = (c**2).sum()
12183*da0073e9SAndroid Build Coastguard Worker            b *= 2
12184*da0073e9SAndroid Build Coastguard Worker            out.backward()
12185*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(a * 2, a.grad))
12186*da0073e9SAndroid Build Coastguard Worker
12187*da0073e9SAndroid Build Coastguard Worker    def test_double_backward(self):
12188*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12189*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(2, 3, requires_grad=True)
12190*da0073e9SAndroid Build Coastguard Worker            b = a.clone()
12191*da0073e9SAndroid Build Coastguard Worker            out = (b**2).sum()
12192*da0073e9SAndroid Build Coastguard Worker            b.sin_()
12193*da0073e9SAndroid Build Coastguard Worker            torch.autograd.grad(out, a, create_graph=True)
12194*da0073e9SAndroid Build Coastguard Worker            (da,) = torch.autograd.grad(out, a, create_graph=True)
12195*da0073e9SAndroid Build Coastguard Worker            (d2a,) = torch.autograd.grad(da.sum(), a)
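            # out = (a**2).sum(), so da = 2*a and the gradient of da.sum()
            # w.r.t. a is 2 everywhere, matching the allclose check below.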
12196*da0073e9SAndroid Build Coastguard Worker
12197*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a))
12198*da0073e9SAndroid Build Coastguard Worker        self.assertClonedLenEqual(ctx, 0)
12199*da0073e9SAndroid Build Coastguard Worker
12200*da0073e9SAndroid Build Coastguard Worker    def test_saved_but_not_anymore(self):
12201*da0073e9SAndroid Build Coastguard Worker        # Make sure we don't clone if the tensor was once saved, but
12202*da0073e9SAndroid Build Coastguard Worker        # by the time we do the in-place op, it is no longer saved
12203*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12204*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(2, 3, requires_grad=True).clone()
12205*da0073e9SAndroid Build Coastguard Worker            out = (a**2).sum()
12206*da0073e9SAndroid Build Coastguard Worker            self.assertTIDMapLenEqual(ctx, 1)
12207*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12208*da0073e9SAndroid Build Coastguard Worker            out.backward()
12209*da0073e9SAndroid Build Coastguard Worker            a.sin_()
12210*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12211*da0073e9SAndroid Build Coastguard Worker            out = (a**2).sum()
12212*da0073e9SAndroid Build Coastguard Worker            a.sin_()
12213*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 1)
12214*da0073e9SAndroid Build Coastguard Worker            del out
12215*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12216*da0073e9SAndroid Build Coastguard Worker
12217*da0073e9SAndroid Build Coastguard Worker    def test_saved_same_tensor_many_times(self):
12218*da0073e9SAndroid Build Coastguard Worker        # We should only clone once
12219*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12220*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(2, 3, requires_grad=True).clone()
12221*da0073e9SAndroid Build Coastguard Worker            b = a**2
12222*da0073e9SAndroid Build Coastguard Worker            c = a**2
12223*da0073e9SAndroid Build Coastguard Worker            a.sin_()
12224*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 1)
12225*da0073e9SAndroid Build Coastguard Worker            del b, c
12226*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12227*da0073e9SAndroid Build Coastguard Worker
12228*da0073e9SAndroid Build Coastguard Worker    def test_saved_same_tensor_different_versions(self):
12229*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12230*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(2, 3, requires_grad=True).clone()
12231*da0073e9SAndroid Build Coastguard Worker            b = a**2
12232*da0073e9SAndroid Build Coastguard Worker            a.sin_()
12233*da0073e9SAndroid Build Coastguard Worker            c = a**2
12234*da0073e9SAndroid Build Coastguard Worker            a.sin_()
12235*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 2)
12236*da0073e9SAndroid Build Coastguard Worker            del b
12237*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 1)
12238*da0073e9SAndroid Build Coastguard Worker            del c
12239*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12240*da0073e9SAndroid Build Coastguard Worker
12241*da0073e9SAndroid Build Coastguard Worker    def test_with_math_views(self):
12242*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12243*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([1 + 1j], requires_grad=True).clone()
12244*da0073e9SAndroid Build Coastguard Worker            b = a.conj()
12245*da0073e9SAndroid Build Coastguard Worker            out = (b**2).sum()
12246*da0073e9SAndroid Build Coastguard Worker            a.sin_()
12247*da0073e9SAndroid Build Coastguard Worker            out.abs().backward()
12248*da0073e9SAndroid Build Coastguard Worker
12249*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([1 + 1j], requires_grad=True).clone()
12250*da0073e9SAndroid Build Coastguard Worker            b = a.conj()
12251*da0073e9SAndroid Build Coastguard Worker            out = (b**2).sum()
12252*da0073e9SAndroid Build Coastguard Worker            # in this case, b no longer appears to be treated as a view
12253*da0073e9SAndroid Build Coastguard Worker            b.sin_()
12254*da0073e9SAndroid Build Coastguard Worker            out.abs().backward()
12255*da0073e9SAndroid Build Coastguard Worker
12256*da0073e9SAndroid Build Coastguard Worker    def test_with_out_variant(self):
12257*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12258*da0073e9SAndroid Build Coastguard Worker            a = torch.tensor([1.0], requires_grad=True)
12259*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([1.0])
12260*da0073e9SAndroid Build Coastguard Worker            c = torch.tensor([2.0])
12261*da0073e9SAndroid Build Coastguard Worker            out = a * b
12262*da0073e9SAndroid Build Coastguard Worker            self.assertTIDMapLenEqual(ctx, 1)
12263*da0073e9SAndroid Build Coastguard Worker            torch.sin(c, out=b)
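            # writing into b via out= is an in-place update of a tensor that was
            # saved by the multiplication above, so the context clones it
            # (verified by the assertion below)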
12264*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 1)
12265*da0073e9SAndroid Build Coastguard Worker            out.backward()
12266*da0073e9SAndroid Build Coastguard Worker            self.assertClonedLenEqual(ctx, 0)
12267*da0073e9SAndroid Build Coastguard Worker
12268*da0073e9SAndroid Build Coastguard Worker    def test_backward_out_of_context(self):
12269*da0073e9SAndroid Build Coastguard Worker        # Out of context
12270*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12271*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(2, 3, requires_grad=True)
12272*da0073e9SAndroid Build Coastguard Worker            out = (a**2).sum()
12273*da0073e9SAndroid Build Coastguard Worker
12274*da0073e9SAndroid Build Coastguard Worker        msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context"
12275*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(AssertionError, msg):
12276*da0073e9SAndroid Build Coastguard Worker            out.backward()
12277*da0073e9SAndroid Build Coastguard Worker
12278*da0073e9SAndroid Build Coastguard Worker        # Different context
12279*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12280*da0073e9SAndroid Build Coastguard Worker            a = torch.rand(2, 3, requires_grad=True)
12281*da0073e9SAndroid Build Coastguard Worker            out = (a**2).sum()
12282*da0073e9SAndroid Build Coastguard Worker
12283*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12284*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(AssertionError, msg):
12285*da0073e9SAndroid Build Coastguard Worker                out.backward()
12286*da0073e9SAndroid Build Coastguard Worker
12287*da0073e9SAndroid Build Coastguard Worker    def test_disallow_nesting(self):
12288*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12289*da0073e9SAndroid Build Coastguard Worker            msg = "allow_mutation_on_saved_tensors contexts cannot be nested"
12290*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, msg):
12291*da0073e9SAndroid Build Coastguard Worker                with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx:
12292*da0073e9SAndroid Build Coastguard Worker                    pass
12293*da0073e9SAndroid Build Coastguard Worker
12294*da0073e9SAndroid Build Coastguard Worker
12295*da0073e9SAndroid Build Coastguard Workerclass TestAutogradInferenceMode(TestCase):
12296*da0073e9SAndroid Build Coastguard Worker    def _is_inference_tensor(self, tensor):
12297*da0073e9SAndroid Build Coastguard Worker        try:
12298*da0073e9SAndroid Build Coastguard Worker            err_msg = "Inference tensors do not track version counter"
12299*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(RuntimeError, err_msg):
12300*da0073e9SAndroid Build Coastguard Worker                tensor._version
12301*da0073e9SAndroid Build Coastguard Worker            return True
12302*da0073e9SAndroid Build Coastguard Worker        except AssertionError:
12303*da0073e9SAndroid Build Coastguard Worker            return False
12304*da0073e9SAndroid Build Coastguard Worker
12305*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_context_manager(self):
12306*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.is_inference_mode_enabled())
12307*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
12308*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference_mode_enabled())
12309*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode(False):
12310*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference_mode_enabled())
12311*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference_mode_enabled())
12312*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.is_inference_mode_enabled())
12313*da0073e9SAndroid Build Coastguard Worker
12314*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_decorator(self):
12315*da0073e9SAndroid Build Coastguard Worker        def func(x):
12316*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(torch.is_inference_mode_enabled(), mode)
12317*da0073e9SAndroid Build Coastguard Worker            return x * x
12318*da0073e9SAndroid Build Coastguard Worker
12319*da0073e9SAndroid Build Coastguard Worker        for mode, use_kwarg in product((True, False, None), (True, False)):
12320*da0073e9SAndroid Build Coastguard Worker            if mode is None:
12321*da0073e9SAndroid Build Coastguard Worker                if use_kwarg:
12322*da0073e9SAndroid Build Coastguard Worker                    decorated = torch.inference_mode(mode=func)
12323*da0073e9SAndroid Build Coastguard Worker                else:
12324*da0073e9SAndroid Build Coastguard Worker                    decorated = torch.inference_mode(func)
12325*da0073e9SAndroid Build Coastguard Worker                mode = True
12326*da0073e9SAndroid Build Coastguard Worker            else:
12327*da0073e9SAndroid Build Coastguard Worker                if use_kwarg:
12328*da0073e9SAndroid Build Coastguard Worker                    decorated = torch.inference_mode(mode=mode)(func)
12329*da0073e9SAndroid Build Coastguard Worker                else:
12330*da0073e9SAndroid Build Coastguard Worker                    decorated = torch.inference_mode(mode)(func)
12331*da0073e9SAndroid Build Coastguard Worker
12332*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12333*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12334*da0073e9SAndroid Build Coastguard Worker                d = decorated(c)
12335*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(not mode or torch.is_inference(d))
12336*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(d.requires_grad, requires_grad and not mode)
12337*da0073e9SAndroid Build Coastguard Worker
12338*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_tensor_creation(self):
12339*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
12340*da0073e9SAndroid Build Coastguard Worker            # new tensors created through constructors are inference tensors
12341*da0073e9SAndroid Build Coastguard Worker            c = torch.ones(1, 2, 3)
12342*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(c.requires_grad)
12343*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference(c))
12344*da0073e9SAndroid Build Coastguard Worker
12345*da0073e9SAndroid Build Coastguard Worker            # requires_grad doesn't change inference tensor behavior in InferenceMode
12346*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(1, 2, 3, requires_grad=True)
12347*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(tmp.requires_grad)
12348*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference(tmp))
12349*da0073e9SAndroid Build Coastguard Worker
12350*da0073e9SAndroid Build Coastguard Worker            tmp = torch.ones(1, 2, 3).requires_grad_(False)
12351*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(tmp.requires_grad)
12352*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference(tmp))
12353*da0073e9SAndroid Build Coastguard Worker
12354*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_existing_autograd_session(self):
12355*da0073e9SAndroid Build Coastguard Worker        s = torch.ones(1, 2, 3, requires_grad=True)
12356*da0073e9SAndroid Build Coastguard Worker        a = s.clone()
12357*da0073e9SAndroid Build Coastguard Worker
12358*da0073e9SAndroid Build Coastguard Worker        # `a` gets saved outside of inference mode
12359*da0073e9SAndroid Build Coastguard Worker        out = a * a
12360*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
12361*da0073e9SAndroid Build Coastguard Worker            a.add_(2)
12362*da0073e9SAndroid Build Coastguard Worker
12363*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.is_inference(a))
12364*da0073e9SAndroid Build Coastguard Worker        # tensors created outside of inference mode aren't
12365*da0073e9SAndroid Build Coastguard Worker        # inference tensors, so they will still have their
12366*da0073e9SAndroid Build Coastguard Worker        # version counters tracked
12367*da0073e9SAndroid Build Coastguard Worker        err_msg = (
12368*da0073e9SAndroid Build Coastguard Worker            "one of the variables needed for gradient computation has been "
12369*da0073e9SAndroid Build Coastguard Worker            "modified by an inplace operation"
12370*da0073e9SAndroid Build Coastguard Worker        )
12371*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg):
12372*da0073e9SAndroid Build Coastguard Worker            out.backward(torch.ones_like(out))
12373*da0073e9SAndroid Build Coastguard Worker
12374*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_inf_tensor_in_inf_mode_functional_op(self):
12375*da0073e9SAndroid Build Coastguard Worker        def functional_op(x):
12376*da0073e9SAndroid Build Coastguard Worker            return x * x
12377*da0073e9SAndroid Build Coastguard Worker
12378*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
12379*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12380*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12381*da0073e9SAndroid Build Coastguard Worker
12382*da0073e9SAndroid Build Coastguard Worker                # performing a non-view operation produces an inference tensor
12383*da0073e9SAndroid Build Coastguard Worker                # that does not require grad
12384*da0073e9SAndroid Build Coastguard Worker                func_out = functional_op(c)
12385*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_inference(func_out))
12386*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(func_out.requires_grad)
12387*da0073e9SAndroid Build Coastguard Worker
12388*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self):
12389*da0073e9SAndroid Build Coastguard Worker        @torch.inference_mode()
12390*da0073e9SAndroid Build Coastguard Worker        def run_test(fn):
12391*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12392*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12393*da0073e9SAndroid Build Coastguard Worker
12394*da0073e9SAndroid Build Coastguard Worker                # after performing an in-place operation, the tensor is still
12395*da0073e9SAndroid Build Coastguard Worker                # an inference tensor
12396*da0073e9SAndroid Build Coastguard Worker                fn(c)
12397*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_inference(c))
12398*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c.requires_grad, requires_grad)
12399*da0073e9SAndroid Build Coastguard Worker
12400*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.add_(2))
12401*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.transpose_(0, 1))
12402*da0073e9SAndroid Build Coastguard Worker
12403*da0073e9SAndroid Build Coastguard Worker        # inplace ops with manual kernel for ADInplaceOrView key in VariableTypeManual.cpp
12404*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.resize_(1, 2))
12405*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.resize_as_(torch.ones(1, 2)))
12406*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.copy_(torch.ones(1, 2, 3)))
12407*da0073e9SAndroid Build Coastguard Worker
12408*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_inf_tensor_in_inf_mode_view_op(self):
12409*da0073e9SAndroid Build Coastguard Worker        with torch.inference_mode():
12410*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12411*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12412*da0073e9SAndroid Build Coastguard Worker
12413*da0073e9SAndroid Build Coastguard Worker                # performing a view operation produces an inference tensor
12414*da0073e9SAndroid Build Coastguard Worker                # that does not require grad
12415*da0073e9SAndroid Build Coastguard Worker                view_out = c.view(-1)
12416*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_inference(view_out))
12417*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(view_out.requires_grad)
12418*da0073e9SAndroid Build Coastguard Worker
12419*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_inf_tensor_in_normal_mode_functional_op(self):
12420*da0073e9SAndroid Build Coastguard Worker        def functional_op(x):
12421*da0073e9SAndroid Build Coastguard Worker            return x * x
12422*da0073e9SAndroid Build Coastguard Worker
12423*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12424*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12425*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12426*da0073e9SAndroid Build Coastguard Worker
12427*da0073e9SAndroid Build Coastguard Worker        func_out = functional_op(c)
12428*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.is_inference(func_out))
12429*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(func_out.requires_grad)
12430*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(func_out.is_leaf)
12431*da0073e9SAndroid Build Coastguard Worker
12432*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self):
12433*da0073e9SAndroid Build Coastguard Worker        def run_test(fn):
12434*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (False, True):
12435*da0073e9SAndroid Build Coastguard Worker                with torch.inference_mode():
12436*da0073e9SAndroid Build Coastguard Worker                    c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12437*da0073e9SAndroid Build Coastguard Worker
12438*da0073e9SAndroid Build Coastguard Worker                if requires_grad:
12439*da0073e9SAndroid Build Coastguard Worker                    # leaf variable that requires grad is being used in an inplace
12440*da0073e9SAndroid Build Coastguard Worker                    # operation when requires_grad=True
12441*da0073e9SAndroid Build Coastguard Worker                    pass
12442*da0073e9SAndroid Build Coastguard Worker                else:
12443*da0073e9SAndroid Build Coastguard Worker                    err_msg = "Inplace update to inference tensor outside InferenceMode"
12444*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, err_msg):
12445*da0073e9SAndroid Build Coastguard Worker                        fn(c)
12446*da0073e9SAndroid Build Coastguard Worker
12447*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.add_(2))
12448*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.transpose_(0, 1))
12449*da0073e9SAndroid Build Coastguard Worker
12450*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_inf_tensor_in_normal_mode_view_op(self):
12451*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12452*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12453*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12454*da0073e9SAndroid Build Coastguard Worker
12455*da0073e9SAndroid Build Coastguard Worker            out = c.view(-1)
12456*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference(out))
12457*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(out.requires_grad)
12458*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(out._is_view())
12459*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(out.is_leaf)
12460*da0073e9SAndroid Build Coastguard Worker
12461*da0073e9SAndroid Build Coastguard Worker    def test_normal_tensor_inplace_output_in_inference_mode(self):
12462*da0073e9SAndroid Build Coastguard Worker        def run_test(fn):
12463*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12464*da0073e9SAndroid Build Coastguard Worker                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12465*da0073e9SAndroid Build Coastguard Worker                a = s.clone()
12466*da0073e9SAndroid Build Coastguard Worker
12467*da0073e9SAndroid Build Coastguard Worker                with torch.inference_mode():
12468*da0073e9SAndroid Build Coastguard Worker                    fn(a)
12469*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_inference(a))
12470*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.requires_grad, requires_grad)
12471*da0073e9SAndroid Build Coastguard Worker
12472*da0073e9SAndroid Build Coastguard Worker                    # inplace -> inplace
12473*da0073e9SAndroid Build Coastguard Worker                    fn(a)
12474*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_inference(a))
12475*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.requires_grad, requires_grad)
12476*da0073e9SAndroid Build Coastguard Worker
12477*da0073e9SAndroid Build Coastguard Worker                    # inplace -> inplace -> view
12478*da0073e9SAndroid Build Coastguard Worker                    view_out = a.view(-1)
12479*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_inference(view_out))
12480*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(view_out.requires_grad, requires_grad)
12481*da0073e9SAndroid Build Coastguard Worker
12482*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.add_(2))
12483*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.transpose_(0, 1))
12484*da0073e9SAndroid Build Coastguard Worker
12485*da0073e9SAndroid Build Coastguard Worker    def test_normal_tensor_inplace_output_in_normal_mode(self):
12486*da0073e9SAndroid Build Coastguard Worker        def run_test(fn):
12487*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12488*da0073e9SAndroid Build Coastguard Worker                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12489*da0073e9SAndroid Build Coastguard Worker                a = s.clone()
12490*da0073e9SAndroid Build Coastguard Worker
12491*da0073e9SAndroid Build Coastguard Worker                with torch.inference_mode():
12492*da0073e9SAndroid Build Coastguard Worker                    fn(a)
12493*da0073e9SAndroid Build Coastguard Worker                    self.assertFalse(torch.is_inference(a))
12494*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(a.requires_grad, requires_grad)
12495*da0073e9SAndroid Build Coastguard Worker
12496*da0073e9SAndroid Build Coastguard Worker                fn(a)
12497*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(a))
12498*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.requires_grad, requires_grad)
12499*da0073e9SAndroid Build Coastguard Worker
12500*da0073e9SAndroid Build Coastguard Worker                # inplace -> inplace
12501*da0073e9SAndroid Build Coastguard Worker                fn(a)
12502*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(a))
12503*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a.requires_grad, requires_grad)
12504*da0073e9SAndroid Build Coastguard Worker
12505*da0073e9SAndroid Build Coastguard Worker                # inplace -> inplace -> view
12506*da0073e9SAndroid Build Coastguard Worker                view_out = a.view(-1)
12507*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(view_out))
12508*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(view_out.requires_grad, requires_grad)
12509*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.add_(2))
12510*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.transpose_(0, 1))
12511*da0073e9SAndroid Build Coastguard Worker
12512*da0073e9SAndroid Build Coastguard Worker    def test_normal_tensor_view_output_in_inference_mode(self):
12513*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12514*da0073e9SAndroid Build Coastguard Worker            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12515*da0073e9SAndroid Build Coastguard Worker            a = s.clone()
12516*da0073e9SAndroid Build Coastguard Worker
12517*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12518*da0073e9SAndroid Build Coastguard Worker                out = a.view(-1)
12519*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(out))
12520*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.requires_grad, requires_grad)
12521*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out._is_view())
12522*da0073e9SAndroid Build Coastguard Worker
12523*da0073e9SAndroid Build Coastguard Worker                # view -> view
12524*da0073e9SAndroid Build Coastguard Worker                tmp = out.view(-1)
12525*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(tmp))
12526*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tmp.requires_grad, requires_grad)
12527*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(tmp._is_view())
12528*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(tmp.is_leaf)
12529*da0073e9SAndroid Build Coastguard Worker
12530*da0073e9SAndroid Build Coastguard Worker                # view -> view -> inplace
12531*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.is_inference_mode_enabled())
12532*da0073e9SAndroid Build Coastguard Worker                tmp.add_(2)
12533*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(tmp))
12534*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(tmp.requires_grad, requires_grad)
12535*da0073e9SAndroid Build Coastguard Worker                # Accessing is_leaf in python tries to update grad_fn and raises:
12536*da0073e9SAndroid Build Coastguard Worker                # A view was created in inference mode and its base or
12537*da0073e9SAndroid Build Coastguard Worker                # another view of its base has been modified inplace in normal mode
12538*da0073e9SAndroid Build Coastguard Worker                # tmp.is_leaf
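                # tmp is a view chain rooted at `a`, and views share their base's
                # version counter, so the add_ above bumped both equally.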
12539*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(a._version, tmp._version)
12540*da0073e9SAndroid Build Coastguard Worker
12541*da0073e9SAndroid Build Coastguard Worker    def test_normal_tensor_view_output_in_normal_mode(self):
12542*da0073e9SAndroid Build Coastguard Worker        def functional_op(x):
12543*da0073e9SAndroid Build Coastguard Worker            return x * x
12544*da0073e9SAndroid Build Coastguard Worker
12545*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12546*da0073e9SAndroid Build Coastguard Worker            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12547*da0073e9SAndroid Build Coastguard Worker            a = s.clone()
12548*da0073e9SAndroid Build Coastguard Worker
12549*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12550*da0073e9SAndroid Build Coastguard Worker                out = a.view(-1)
12551*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(torch.is_inference(out))
12552*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(out.requires_grad, requires_grad)
12553*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out._is_view())
12554*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(out.is_leaf)
12555*da0073e9SAndroid Build Coastguard Worker
12556*da0073e9SAndroid Build Coastguard Worker            tmp = functional_op(out)
12557*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_inference(tmp))
12558*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tmp.requires_grad, requires_grad)
12559*da0073e9SAndroid Build Coastguard Worker
12560*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
12561*da0073e9SAndroid Build Coastguard Worker                err_msg = (
12562*da0073e9SAndroid Build Coastguard Worker                    "A view was created in inference mode and is being modified inplace"
12563*da0073e9SAndroid Build Coastguard Worker                )
12564*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, err_msg):
12565*da0073e9SAndroid Build Coastguard Worker                    out.add_(2)
12566*da0073e9SAndroid Build Coastguard Worker            else:
12567*da0073e9SAndroid Build Coastguard Worker                out.add_(2)
12568*da0073e9SAndroid Build Coastguard Worker
12569*da0073e9SAndroid Build Coastguard Worker            tmp = out.view(2, 3)
12570*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_inference(tmp))
12571*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tmp.requires_grad, requires_grad)
12572*da0073e9SAndroid Build Coastguard Worker
12573*da0073e9SAndroid Build Coastguard Worker    def test_mix_inference_and_normal_tensor_functional_op(self):
12574*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12575*da0073e9SAndroid Build Coastguard Worker            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12576*da0073e9SAndroid Build Coastguard Worker
12577*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12578*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3, requires_grad=requires_grad)
12579*da0073e9SAndroid Build Coastguard Worker
12580*da0073e9SAndroid Build Coastguard Worker            # add is safe since it doesn't save any variable for backward
12581*da0073e9SAndroid Build Coastguard Worker            out = c.add(s)
12582*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_inference(out))
12583*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(out.requires_grad, requires_grad)
12584*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
12585*da0073e9SAndroid Build Coastguard Worker                # a leaf inference tensor with requires_grad=True can still receive a gradient
12586*da0073e9SAndroid Build Coastguard Worker                out.backward(torch.ones_like(out))
12587*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(c.grad, torch.ones_like(c))
12588*da0073e9SAndroid Build Coastguard Worker
12589*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
12590*da0073e9SAndroid Build Coastguard Worker                err_msg = "Inference tensors cannot be saved for backward"
12591*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, err_msg):
12592*da0073e9SAndroid Build Coastguard Worker                    c * s
12593*da0073e9SAndroid Build Coastguard Worker
12594*da0073e9SAndroid Build Coastguard Worker                # TODO: Test this with an autograd.Function when that works;
12595*da0073e9SAndroid Build Coastguard Worker                #       torch.stack stopped capturing a TensorList input
12596*da0073e9SAndroid Build Coastguard Worker                # # inference tensor in TensorList input
12597*da0073e9SAndroid Build Coastguard Worker                # inputs = [s, c]
12598*da0073e9SAndroid Build Coastguard Worker                # with self.assertRaisesRegex(RuntimeError, err_msg):
12599*da0073e9SAndroid Build Coastguard Worker                #     torch.stack(inputs)
12600*da0073e9SAndroid Build Coastguard Worker
12601*da0073e9SAndroid Build Coastguard Worker    def test_mix_inference_and_normal_tensor_inplace_op(self):
12602*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12603*da0073e9SAndroid Build Coastguard Worker            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12604*da0073e9SAndroid Build Coastguard Worker            a = s.clone()
12605*da0073e9SAndroid Build Coastguard Worker
12606*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12607*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3)
12608*da0073e9SAndroid Build Coastguard Worker
12609*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference(c))
12610*da0073e9SAndroid Build Coastguard Worker            if requires_grad:
12611*da0073e9SAndroid Build Coastguard Worker                err_msg = "Inference tensors cannot be saved for backward"
12612*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, err_msg):
12613*da0073e9SAndroid Build Coastguard Worker                    a.mul_(c)
12614*da0073e9SAndroid Build Coastguard Worker
12615*da0073e9SAndroid Build Coastguard Worker                # inference tensor passed as the out= argument
12616*da0073e9SAndroid Build Coastguard Worker                err_msg = (
12617*da0073e9SAndroid Build Coastguard Worker                    "out=... arguments don't support automatic differentiation, "
12618*da0073e9SAndroid Build Coastguard Worker                    "but one of the arguments requires grad"
12619*da0073e9SAndroid Build Coastguard Worker                )
12620*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, err_msg):
12621*da0073e9SAndroid Build Coastguard Worker                    torch.mul(s, s, out=c)
12622*da0073e9SAndroid Build Coastguard Worker            else:
12623*da0073e9SAndroid Build Coastguard Worker                a.mul_(c)
12624*da0073e9SAndroid Build Coastguard Worker                err_msg = "Inplace update to inference tensor outside InferenceMode is not allowed"
12625*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(RuntimeError, err_msg):
12626*da0073e9SAndroid Build Coastguard Worker                    torch.mul(s, s, out=c)
12627*da0073e9SAndroid Build Coastguard Worker
12628*da0073e9SAndroid Build Coastguard Worker    def test_mix_inference_and_normal_tensor_view_op(self):
12629*da0073e9SAndroid Build Coastguard Worker        for requires_grad in (True, False):
12630*da0073e9SAndroid Build Coastguard Worker            s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12631*da0073e9SAndroid Build Coastguard Worker
12632*da0073e9SAndroid Build Coastguard Worker            with torch.inference_mode():
12633*da0073e9SAndroid Build Coastguard Worker                c = torch.ones(1, 2, 3)
12634*da0073e9SAndroid Build Coastguard Worker
12635*da0073e9SAndroid Build Coastguard Worker            # view_as is a composite op which calls view with only one
12636*da0073e9SAndroid Build Coastguard Worker            # tensor argument, so there is no mix of inference and normal
12637*da0073e9SAndroid Build Coastguard Worker            # tensor inputs for view ops
12638*da0073e9SAndroid Build Coastguard Worker            tmp1 = c.view_as(s)
12639*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.is_inference(tmp1))
12640*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(tmp1.requires_grad)
12641*da0073e9SAndroid Build Coastguard Worker
12642*da0073e9SAndroid Build Coastguard Worker            # this is fine since it's equivalent to s.view(c.sizes()), which
12643*da0073e9SAndroid Build Coastguard Worker            # isn't a mixed-input scenario
12644*da0073e9SAndroid Build Coastguard Worker            tmp2 = s.view_as(c)
12645*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.is_inference(tmp2))
12646*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(tmp2.requires_grad, requires_grad)
12647*da0073e9SAndroid Build Coastguard Worker
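    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the basic inference-mode facts the tests above rely on, using
    # only the public torch.inference_mode / torch.is_inference APIs.
    def _sketch_inference_mode_basics(self):
        with torch.inference_mode():
            self.assertTrue(torch.is_inference_mode_enabled())
            c = torch.ones(1, 2, 3)  # created inside the mode -> inference tensor
            self.assertTrue(torch.is_inference(c))
        s = torch.ones(1, 2, 3, requires_grad=True)
        out = s * s  # functional op on a normal tensor outside the mode
        self.assertFalse(torch.is_inference(out))
        self.assertTrue(out.requires_grad)
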
12648*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_handle_direct_view_on_rebase(self):
12649*da0073e9SAndroid Build Coastguard Worker        def run_test(fn):
12650*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12651*da0073e9SAndroid Build Coastguard Worker                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12652*da0073e9SAndroid Build Coastguard Worker                a = s.clone()
12653*da0073e9SAndroid Build Coastguard Worker
12654*da0073e9SAndroid Build Coastguard Worker                with torch.inference_mode():
12655*da0073e9SAndroid Build Coastguard Worker                    view_out = a.view_as(a)
12656*da0073e9SAndroid Build Coastguard Worker
12657*da0073e9SAndroid Build Coastguard Worker                if requires_grad:
12658*da0073e9SAndroid Build Coastguard Worker                    err_msg = "A view was created in inference mode and is being modified inplace"
12659*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, err_msg):
12660*da0073e9SAndroid Build Coastguard Worker                        fn(view_out)
12661*da0073e9SAndroid Build Coastguard Worker                else:
12662*da0073e9SAndroid Build Coastguard Worker                    fn(view_out)
12663*da0073e9SAndroid Build Coastguard Worker
12664*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.add_(2))
12665*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.transpose_(0, 1))
12666*da0073e9SAndroid Build Coastguard Worker
12667*da0073e9SAndroid Build Coastguard Worker    def test_inference_mode_handle_indirect_view_on_rebase(self):
12668*da0073e9SAndroid Build Coastguard Worker        def run_test(fn):
12669*da0073e9SAndroid Build Coastguard Worker            for requires_grad in (True, False):
12670*da0073e9SAndroid Build Coastguard Worker                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
12671*da0073e9SAndroid Build Coastguard Worker                a = s.clone()
12672*da0073e9SAndroid Build Coastguard Worker
12673*da0073e9SAndroid Build Coastguard Worker                with torch.inference_mode():
12674*da0073e9SAndroid Build Coastguard Worker                    view_out = a.view(-1)
12675*da0073e9SAndroid Build Coastguard Worker
12676*da0073e9SAndroid Build Coastguard Worker                fn(a)
12677*da0073e9SAndroid Build Coastguard Worker                if requires_grad:
12678*da0073e9SAndroid Build Coastguard Worker                    err_msg = "A view was created in inference mode and its base or another view "
12679*da0073e9SAndroid Build Coastguard Worker                    with self.assertRaisesRegex(RuntimeError, err_msg):
12680*da0073e9SAndroid Build Coastguard Worker                        view_out.grad_fn
12681*da0073e9SAndroid Build Coastguard Worker                else:
12682*da0073e9SAndroid Build Coastguard Worker                    view_out.grad_fn
12683*da0073e9SAndroid Build Coastguard Worker
12684*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.add_(2))
12685*da0073e9SAndroid Build Coastguard Worker        run_test(lambda x: x.transpose_(0, 1))
12686*da0073e9SAndroid Build Coastguard Worker
12687*da0073e9SAndroid Build Coastguard Worker
12688*da0073e9SAndroid Build Coastguard Workerclass TestMultithreadAutograd(TestCase):
12689*da0073e9SAndroid Build Coastguard Worker    def _run_py_multithread_fn(
12690*da0073e9SAndroid Build Coastguard Worker        self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False
12691*da0073e9SAndroid Build Coastguard Worker    ):
12692*da0073e9SAndroid Build Coastguard Worker        class PropagatingThread(threading.Thread):
12693*da0073e9SAndroid Build Coastguard Worker            """Helper class to propagate an exception from a child
12694*da0073e9SAndroid Build Coastguard Worker            thread to the main thread on join.
12695*da0073e9SAndroid Build Coastguard Worker
12696*da0073e9SAndroid Build Coastguard Worker            Reference: https://stackoverflow.com/a/31614591/5602957
12697*da0073e9SAndroid Build Coastguard Worker            """
12698*da0073e9SAndroid Build Coastguard Worker
12699*da0073e9SAndroid Build Coastguard Worker            def run(self):
12700*da0073e9SAndroid Build Coastguard Worker                self.exception = None
12701*da0073e9SAndroid Build Coastguard Worker                try:
12702*da0073e9SAndroid Build Coastguard Worker                    self.ret = super().run()
12703*da0073e9SAndroid Build Coastguard Worker                except Exception as e:
12704*da0073e9SAndroid Build Coastguard Worker                    self.exception = e
12705*da0073e9SAndroid Build Coastguard Worker
12706*da0073e9SAndroid Build Coastguard Worker            def join(self, timeout=None):
12707*da0073e9SAndroid Build Coastguard Worker                super().join(timeout)
12708*da0073e9SAndroid Build Coastguard Worker                if self.exception:
12709*da0073e9SAndroid Build Coastguard Worker                    raise self.exception from self.exception
12710*da0073e9SAndroid Build Coastguard Worker                return self.ret
12711*da0073e9SAndroid Build Coastguard Worker
12712*da0073e9SAndroid Build Coastguard Worker        threads = []
12713*da0073e9SAndroid Build Coastguard Worker        for idx in range(num_threads):
12714*da0073e9SAndroid Build Coastguard Worker            p = PropagatingThread(target=fn, args=((idx, *args) if pass_idx else args))
12715*da0073e9SAndroid Build Coastguard Worker            p.start()
12716*da0073e9SAndroid Build Coastguard Worker            threads.append(p)
12717*da0073e9SAndroid Build Coastguard Worker
12718*da0073e9SAndroid Build Coastguard Worker        for p in threads:
12719*da0073e9SAndroid Build Coastguard Worker            p.join()
12720*da0073e9SAndroid Build Coastguard Worker
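    # Illustrative sketch (added for exposition; not part of the original test
    # suite): minimal usage of the helper above -- run a closure on several
    # Python threads; PropagatingThread re-raises any child-thread failure in
    # the joining thread.
    def _sketch_run_py_multithread_fn_usage(self):
        results = []

        def work():
            x = torch.ones(2, requires_grad=True)
            (x * 3).sum().backward()
            results.append(x.grad.clone())  # list.append is thread-safe under the GIL

        self._run_py_multithread_fn(work, num_threads=4)
        self.assertEqual(len(results), 4)
        for r in results:
            self.assertEqual(r, torch.full((2,), 3.0))
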
12721*da0073e9SAndroid Build Coastguard Worker    def test_multithreaded_exception_propagation(self):
12722*da0073e9SAndroid Build Coastguard Worker        # Test whether exceptions raised in a child thread
12723*da0073e9SAndroid Build Coastguard Worker        # are propagated to the main thread.
12724*da0073e9SAndroid Build Coastguard Worker        def fn():
12725*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(False)
12726*da0073e9SAndroid Build Coastguard Worker
12727*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(AssertionError):
12728*da0073e9SAndroid Build Coastguard Worker            self._run_py_multithread_fn(fn)
12729*da0073e9SAndroid Build Coastguard Worker
12730*da0073e9SAndroid Build Coastguard Worker    def test_simple_backward(self):
12731*da0073e9SAndroid Build Coastguard Worker        # simple multithreaded backward that creates threads at the beginning of training,
12732*da0073e9SAndroid Build Coastguard Worker        # while everything else (inputs, operations, etc.) is separate per thread
12733*da0073e9SAndroid Build Coastguard Worker        def train_fn():
12734*da0073e9SAndroid Build Coastguard Worker            x = torch.ones(5, 5, requires_grad=True)
12735*da0073e9SAndroid Build Coastguard Worker            y = (x + 3) * (x + 4) * 0.5
12736*da0073e9SAndroid Build Coastguard Worker            y.sum().backward()
12737*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, x + 3.5)
12738*da0073e9SAndroid Build Coastguard Worker
12739*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(train_fn)
12740*da0073e9SAndroid Build Coastguard Worker
12741*da0073e9SAndroid Build Coastguard Worker    def test_simple_backward_same_input(self):
12742*da0073e9SAndroid Build Coastguard Worker        # simple multithreaded backward with only shared inputs (this is common
12743*da0073e9SAndroid Build Coastguard Worker        # for things like Hogwild multithreaded training with multiple CPU threads)
12744*da0073e9SAndroid Build Coastguard Worker        def train_fn_backward(x):
12745*da0073e9SAndroid Build Coastguard Worker            y = (x + 3) * (x + 4) * 0.5
12746*da0073e9SAndroid Build Coastguard Worker            y.sum().backward()
12747*da0073e9SAndroid Build Coastguard Worker
12748*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
12749*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(train_fn_backward, (x,))
12750*da0073e9SAndroid Build Coastguard Worker        # Since we call backward from multiple threads
12751*da0073e9SAndroid Build Coastguard Worker        # and all threads share the same input, the concurrent
12752*da0073e9SAndroid Build Coastguard Worker        # backward calls all accumulate into the same .grad
12753*da0073e9SAndroid Build Coastguard Worker        # for each input, so the accumulated gradient should
12754*da0073e9SAndroid Build Coastguard Worker        # equal num_threads * gradient
12755*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, 10 * (x + 3.5))
12756*da0073e9SAndroid Build Coastguard Worker
12757*da0073e9SAndroid Build Coastguard Worker        def train_fn_grad(x):
12758*da0073e9SAndroid Build Coastguard Worker            y = (x + 3) * (x + 4) * 0.5
12759*da0073e9SAndroid Build Coastguard Worker            grads = torch.autograd.grad(y.sum(), x)
12760*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(grads), 1)
12761*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(grads[0], x + 3.5)
12762*da0073e9SAndroid Build Coastguard Worker
12763*da0073e9SAndroid Build Coastguard Worker        # since we use the functional grad() API, gradients are not accumulated into
12764*da0073e9SAndroid Build Coastguard Worker        # a shared .grad and should be identical (see the sketch after this test)
12765*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(train_fn_grad, (x,))
12766*da0073e9SAndroid Build Coastguard Worker
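    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the same contrast as above without threads -- backward()
    # accumulates into a shared .grad, while torch.autograd.grad() returns
    # fresh gradients and leaves .grad untouched.
    def _sketch_grad_accumulation_vs_functional_grad(self):
        x = torch.ones(2, requires_grad=True)
        (x * 2).sum().backward()
        (x * 2).sum().backward()
        self.assertEqual(x.grad, torch.full_like(x, 4.0))  # 2 + 2 accumulated
        (g,) = torch.autograd.grad((x * 2).sum(), x)
        self.assertEqual(g, torch.full_like(x, 2.0))
        self.assertEqual(x.grad, torch.full_like(x, 4.0))  # unchanged by grad()
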
12767*da0073e9SAndroid Build Coastguard Worker    def test_multi_grad_all_hooks(self):
12768*da0073e9SAndroid Build Coastguard Worker        # Multi-grad hooks should behave independently per execution of backward.
12769*da0073e9SAndroid Build Coastguard Worker        # Test that the hook fires once for each backward call we make,
12770*da0073e9SAndroid Build Coastguard Worker        # even if those executions occur concurrently on different threads
12771*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand(2, requires_grad=True)
12772*da0073e9SAndroid Build Coastguard Worker        t2 = torch.rand(2, requires_grad=True)
12773*da0073e9SAndroid Build Coastguard Worker        t3 = torch.rand(2, requires_grad=True)
12774*da0073e9SAndroid Build Coastguard Worker        t4 = torch.rand(2, requires_grad=True)
12775*da0073e9SAndroid Build Coastguard Worker
12776*da0073e9SAndroid Build Coastguard Worker        res = None
12777*da0073e9SAndroid Build Coastguard Worker        count = [0]
12778*da0073e9SAndroid Build Coastguard Worker        hook_lock = threading.Lock()
12779*da0073e9SAndroid Build Coastguard Worker
12780*da0073e9SAndroid Build Coastguard Worker        def hook(grads):
12781*da0073e9SAndroid Build Coastguard Worker            nonlocal res
12782*da0073e9SAndroid Build Coastguard Worker            with hook_lock:
12783*da0073e9SAndroid Build Coastguard Worker                count[0] += 1
12784*da0073e9SAndroid Build Coastguard Worker                grad_is_present = [g is not None for g in grads]
12785*da0073e9SAndroid Build Coastguard Worker                if res is None:
12786*da0073e9SAndroid Build Coastguard Worker                    res = grad_is_present
12787*da0073e9SAndroid Build Coastguard Worker                else:
12788*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res, grad_is_present)
12789*da0073e9SAndroid Build Coastguard Worker
12790*da0073e9SAndroid Build Coastguard Worker        torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)
12791*da0073e9SAndroid Build Coastguard Worker
12792*da0073e9SAndroid Build Coastguard Worker        out = (t2 * t3).sum()
12793*da0073e9SAndroid Build Coastguard Worker
12794*da0073e9SAndroid Build Coastguard Worker        def backward_retain_graph(out, t2, t3):
12795*da0073e9SAndroid Build Coastguard Worker            out.backward(inputs=(t2, t3), retain_graph=True)
12796*da0073e9SAndroid Build Coastguard Worker
12797*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12798*da0073e9SAndroid Build Coastguard Worker
12799*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 5)
12800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, [False, True, True, False])
12801*da0073e9SAndroid Build Coastguard Worker
12802*da0073e9SAndroid Build Coastguard Worker        # Raise an error in one thread's backward so one hook is left partially applied
12803*da0073e9SAndroid Build Coastguard Worker        res = None
12804*da0073e9SAndroid Build Coastguard Worker        count = [0]
12805*da0073e9SAndroid Build Coastguard Worker        err_count = [0]
12806*da0073e9SAndroid Build Coastguard Worker        bw_count = [0]
12807*da0073e9SAndroid Build Coastguard Worker        bw_count_lock = threading.Lock()
12808*da0073e9SAndroid Build Coastguard Worker        err_count_lock = threading.Lock()
12809*da0073e9SAndroid Build Coastguard Worker
12810*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
12811*da0073e9SAndroid Build Coastguard Worker            @staticmethod
12812*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
12813*da0073e9SAndroid Build Coastguard Worker                return x
12814*da0073e9SAndroid Build Coastguard Worker
12815*da0073e9SAndroid Build Coastguard Worker            @staticmethod
12816*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
12817*da0073e9SAndroid Build Coastguard Worker                with bw_count_lock:
12818*da0073e9SAndroid Build Coastguard Worker                    bw_count[0] += 1
12819*da0073e9SAndroid Build Coastguard Worker                    if bw_count[0] == 1:
12820*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError("error message")
12821*da0073e9SAndroid Build Coastguard Worker                    else:
12822*da0073e9SAndroid Build Coastguard Worker                        return gO
12823*da0073e9SAndroid Build Coastguard Worker
12824*da0073e9SAndroid Build Coastguard Worker        out = (Func.apply(t2) * t3).sum()
12825*da0073e9SAndroid Build Coastguard Worker
12826*da0073e9SAndroid Build Coastguard Worker        def backward_retain_graph(out, t2, t3):
12827*da0073e9SAndroid Build Coastguard Worker            try:
12828*da0073e9SAndroid Build Coastguard Worker                out.backward(inputs=(t2, t3), retain_graph=True)
12829*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
12830*da0073e9SAndroid Build Coastguard Worker                with err_count_lock:
12831*da0073e9SAndroid Build Coastguard Worker                    err_count[0] += 1
12832*da0073e9SAndroid Build Coastguard Worker
12833*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12834*da0073e9SAndroid Build Coastguard Worker
12835*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 4)
12836*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(err_count[0], 1)
12837*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, [False, True, True, False])
12838*da0073e9SAndroid Build Coastguard Worker
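    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the two modes of register_multi_grad_hook used in these tests,
    # shown single-threaded.  With mode="all" the hook receives one
    # (possibly-None) grad per registered tensor; with mode="any" it receives
    # the single grad of whichever registered tensor had its grad computed.
    def _sketch_multi_grad_hook_modes(self):
        a = torch.rand(2, requires_grad=True)
        b = torch.rand(2, requires_grad=True)
        seen = {}

        def all_hook(grads):
            # one entry per registered tensor; b is not part of the graph below
            seen["all"] = [g is not None for g in grads]

        def any_hook(grad):
            seen["any"] = grad is not None

        torch.autograd.graph.register_multi_grad_hook((a, b), all_hook)
        torch.autograd.graph.register_multi_grad_hook((a, b), any_hook, mode="any")
        (a * 2).sum().backward()
        self.assertEqual(seen["all"], [True, False])
        self.assertTrue(seen["any"])
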
12839*da0073e9SAndroid Build Coastguard Worker    def test_multi_grad_any_hooks(self):
12840*da0073e9SAndroid Build Coastguard Worker        # Multi-grad hooks should behave independently per execution of backward.
12841*da0073e9SAndroid Build Coastguard Worker        # Test that the hook fires once for each backward call we make,
12842*da0073e9SAndroid Build Coastguard Worker        # even if those executions occur concurrently on different threads
12843*da0073e9SAndroid Build Coastguard Worker        t1 = torch.rand(2, requires_grad=True)
12844*da0073e9SAndroid Build Coastguard Worker        t2 = torch.rand(2, requires_grad=True)
12845*da0073e9SAndroid Build Coastguard Worker        t3 = torch.rand(2, requires_grad=True)
12846*da0073e9SAndroid Build Coastguard Worker        t4 = torch.rand(2, requires_grad=True)
12847*da0073e9SAndroid Build Coastguard Worker
12848*da0073e9SAndroid Build Coastguard Worker        res = None
12849*da0073e9SAndroid Build Coastguard Worker        count = [0]
12850*da0073e9SAndroid Build Coastguard Worker        hook_lock = threading.Lock()
12851*da0073e9SAndroid Build Coastguard Worker
12852*da0073e9SAndroid Build Coastguard Worker        def hook(grad):
12853*da0073e9SAndroid Build Coastguard Worker            nonlocal res
12854*da0073e9SAndroid Build Coastguard Worker            with hook_lock:
12855*da0073e9SAndroid Build Coastguard Worker                count[0] += 1
12856*da0073e9SAndroid Build Coastguard Worker                if res is None:
12857*da0073e9SAndroid Build Coastguard Worker                    res = "foo"
12858*da0073e9SAndroid Build Coastguard Worker                else:
12859*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(res, "foo")
12860*da0073e9SAndroid Build Coastguard Worker
12861*da0073e9SAndroid Build Coastguard Worker        torch.autograd.graph.register_multi_grad_hook(
12862*da0073e9SAndroid Build Coastguard Worker            (t1, t2, t3, t4), hook, mode="any"
12863*da0073e9SAndroid Build Coastguard Worker        )
12864*da0073e9SAndroid Build Coastguard Worker
12865*da0073e9SAndroid Build Coastguard Worker        out = (t2 * t3).sum()
12866*da0073e9SAndroid Build Coastguard Worker
12867*da0073e9SAndroid Build Coastguard Worker        def backward_retain_graph(out, t2, t3):
12868*da0073e9SAndroid Build Coastguard Worker            out.backward(inputs=(t2, t3), retain_graph=True)
12869*da0073e9SAndroid Build Coastguard Worker
12870*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12871*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 5)
12872*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, "foo")
12873*da0073e9SAndroid Build Coastguard Worker
12874*da0073e9SAndroid Build Coastguard Worker        # Raise an error in one thread's backward
12875*da0073e9SAndroid Build Coastguard Worker        res = None
12876*da0073e9SAndroid Build Coastguard Worker        count = [0]
12877*da0073e9SAndroid Build Coastguard Worker        err_count = [0]
12878*da0073e9SAndroid Build Coastguard Worker        bw_count = [0]
12879*da0073e9SAndroid Build Coastguard Worker        bw_count_lock = threading.Lock()
12880*da0073e9SAndroid Build Coastguard Worker        err_count_lock = threading.Lock()
12881*da0073e9SAndroid Build Coastguard Worker
12882*da0073e9SAndroid Build Coastguard Worker        class Func(torch.autograd.Function):
12883*da0073e9SAndroid Build Coastguard Worker            @staticmethod
12884*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
12885*da0073e9SAndroid Build Coastguard Worker                return x
12886*da0073e9SAndroid Build Coastguard Worker
12887*da0073e9SAndroid Build Coastguard Worker            @staticmethod
12888*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
12889*da0073e9SAndroid Build Coastguard Worker                with bw_count_lock:
12890*da0073e9SAndroid Build Coastguard Worker                    bw_count[0] += 1
12891*da0073e9SAndroid Build Coastguard Worker                    if bw_count[0] == 1:
12892*da0073e9SAndroid Build Coastguard Worker                        raise RuntimeError("error message")
12893*da0073e9SAndroid Build Coastguard Worker                    else:
12894*da0073e9SAndroid Build Coastguard Worker                        return gO
12895*da0073e9SAndroid Build Coastguard Worker
12896*da0073e9SAndroid Build Coastguard Worker        out = (Func.apply(t2) * t3).sum()
12897*da0073e9SAndroid Build Coastguard Worker
12898*da0073e9SAndroid Build Coastguard Worker        def backward_retain_graph(out, t2, t3):
12899*da0073e9SAndroid Build Coastguard Worker            try:
12900*da0073e9SAndroid Build Coastguard Worker                out.backward(inputs=(t2, t3), retain_graph=True)
12901*da0073e9SAndroid Build Coastguard Worker            except RuntimeError:
12902*da0073e9SAndroid Build Coastguard Worker                with err_count_lock:
12903*da0073e9SAndroid Build Coastguard Worker                    err_count[0] += 1
12904*da0073e9SAndroid Build Coastguard Worker
12905*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
12906*da0073e9SAndroid Build Coastguard Worker
12907*da0073e9SAndroid Build Coastguard Worker        # Expect all 5 threads to increment count since the hook runs before
12908*da0073e9SAndroid Build Coastguard Worker        # the custom backward
12909*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(count[0], 5)
12910*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(err_count[0], 1)
12911*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, "foo")
12912*da0073e9SAndroid Build Coastguard Worker
12913*da0073e9SAndroid Build Coastguard Worker    def test_dataparallel_saved_tensors_hooks(self):
12914*da0073e9SAndroid Build Coastguard Worker        def pack(x):
12915*da0073e9SAndroid Build Coastguard Worker            warnings.warn("pack")
12916*da0073e9SAndroid Build Coastguard Worker            return x
12917*da0073e9SAndroid Build Coastguard Worker
12918*da0073e9SAndroid Build Coastguard Worker        _self = self
12919*da0073e9SAndroid Build Coastguard Worker
12920*da0073e9SAndroid Build Coastguard Worker        class Model(torch.nn.Module):
12921*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
12922*da0073e9SAndroid Build Coastguard Worker                with warnings.catch_warnings(record=True) as w:
12923*da0073e9SAndroid Build Coastguard Worker                    y = x * x
12924*da0073e9SAndroid Build Coastguard Worker                    if torch.cuda.device_count() >= 2:
12925*da0073e9SAndroid Build Coastguard Worker                        # DataParallel is calling the forward in different threads
12926*da0073e9SAndroid Build Coastguard Worker                        # without progating TLS, so hooks should not be called here
12927*da0073e9SAndroid Build Coastguard Worker                        # without propagating TLS, so hooks should not be called here
12928*da0073e9SAndroid Build Coastguard Worker                    else:
12929*da0073e9SAndroid Build Coastguard Worker                        # DataParallel only uses one thread
12930*da0073e9SAndroid Build Coastguard Worker                        # so hooks should be called here
12931*da0073e9SAndroid Build Coastguard Worker                        _self.assertGreater(len(w), 0)
12932*da0073e9SAndroid Build Coastguard Worker
12933*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(5, 5, requires_grad=True)
12934*da0073e9SAndroid Build Coastguard Worker        model = torch.nn.DataParallel(Model())
12935*da0073e9SAndroid Build Coastguard Worker
12936*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
12937*da0073e9SAndroid Build Coastguard Worker            model(x)
12938*da0073e9SAndroid Build Coastguard Worker            with warnings.catch_warnings(record=True) as w:
12939*da0073e9SAndroid Build Coastguard Worker                y = x * x
12940*da0073e9SAndroid Build Coastguard Worker                # hooks should be called here
12941*da0073e9SAndroid Build Coastguard Worker                _self.assertGreater(len(w), 0)
12942*da0073e9SAndroid Build Coastguard Worker
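    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the pack/unpack hooks from the test above without DataParallel.
    # pack runs when autograd saves a tensor for backward; unpack runs when
    # backward retrieves it.
    def _sketch_saved_tensors_hooks(self):
        calls = []

        def pack(x):
            calls.append("pack")
            return x

        def unpack(x):
            calls.append("unpack")
            return x

        a = torch.ones(2, 2, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            y = a * a  # mul saves its inputs for backward -> pack is called
        y.sum().backward()  # backward unpacks the saved tensors
        self.assertTrue("pack" in calls)
        self.assertTrue("unpack" in calls)
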
12943*da0073e9SAndroid Build Coastguard Worker    def test_python_thread_in_middle(self):
12944*da0073e9SAndroid Build Coastguard Worker        # User might write a network that starts on one CPU thread, then runs its second half
12945*da0073e9SAndroid Build Coastguard Worker        # concurrently with other threads (either via python threading or fork/join calls),
12946*da0073e9SAndroid Build Coastguard Worker        # then calls backward()/grad() on BOTH threads, like a Y pattern from input at the
12947*da0073e9SAndroid Build Coastguard Worker        # bottom to output at the top. This way part of the GraphTask is being shared across
12948*da0073e9SAndroid Build Coastguard Worker        # different threads, so we need to ensure the user specifies retain_graph=True;
12949*da0073e9SAndroid Build Coastguard Worker        # otherwise we error out with the correct error message
12950*da0073e9SAndroid Build Coastguard Worker
12951*da0073e9SAndroid Build Coastguard Worker        # Case 1: multiple backward with python threads, retain_graph=False
12952*da0073e9SAndroid Build Coastguard Worker        # should throw error in some threads with no retain_graph.
12953*da0073e9SAndroid Build Coastguard Worker        success_vs_raises = [0, 0]
12954*da0073e9SAndroid Build Coastguard Worker
12955*da0073e9SAndroid Build Coastguard Worker        def train_fn_no_retain_graph(x):
12956*da0073e9SAndroid Build Coastguard Worker            y = x + x**2
12957*da0073e9SAndroid Build Coastguard Worker            try:
12958*da0073e9SAndroid Build Coastguard Worker                y.sum().backward()
12959*da0073e9SAndroid Build Coastguard Worker                success_vs_raises[0] += 1
12960*da0073e9SAndroid Build Coastguard Worker            except RuntimeError as error:
12961*da0073e9SAndroid Build Coastguard Worker                success_vs_raises[1] += 1
12962*da0073e9SAndroid Build Coastguard Worker                self.assertRegex(str(error), "Specify retain_graph=True")
12963*da0073e9SAndroid Build Coastguard Worker
12964*da0073e9SAndroid Build Coastguard Worker        x_no_retain = torch.ones(5, 5, requires_grad=True)
12965*da0073e9SAndroid Build Coastguard Worker        y_no_retain = x_no_retain + x_no_retain**2
12966*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(
12967*da0073e9SAndroid Build Coastguard Worker            train_fn_no_retain_graph, (y_no_retain,), num_threads=5
12968*da0073e9SAndroid Build Coastguard Worker        )
12969*da0073e9SAndroid Build Coastguard Worker        # at least one thread will succeed in this case; all other threads should raise
12970*da0073e9SAndroid Build Coastguard Worker        # the error that tells the user to specify retain_graph=True
12971*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(success_vs_raises[0] >= 1)
12972*da0073e9SAndroid Build Coastguard Worker
12973*da0073e9SAndroid Build Coastguard Worker        # Case 2: multiple backward with python threads, no error with retain_graph=True
12974*da0073e9SAndroid Build Coastguard Worker        def train_fn_retain_graph(x):
12975*da0073e9SAndroid Build Coastguard Worker            y = x + x**2
12976*da0073e9SAndroid Build Coastguard Worker            y.sum().backward(retain_graph=True)
12977*da0073e9SAndroid Build Coastguard Worker
12978*da0073e9SAndroid Build Coastguard Worker        x_retain = torch.ones(5, 5, requires_grad=True)
12979*da0073e9SAndroid Build Coastguard Worker        y_retain = x_retain + x_retain**2
12980*da0073e9SAndroid Build Coastguard Worker        self._run_py_multithread_fn(train_fn_retain_graph, (y_retain,), num_threads=5)
12981*da0073e9SAndroid Build Coastguard Worker        # result should equal num_threads * gradient
12982*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
12983*da0073e9SAndroid Build Coastguard Worker            x_retain.grad,
12984*da0073e9SAndroid Build Coastguard Worker            5 * (4 * x_retain**3 + 6 * (x_retain**2) + 4 * x_retain + 1),
12985*da0073e9SAndroid Build Coastguard Worker        )
12986*da0073e9SAndroid Build Coastguard Worker
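    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the single-threaded version of the retain_graph requirement
    # exercised above -- a second backward over the same graph raises unless
    # the first call used retain_graph=True.
    def _sketch_retain_graph_requirement(self):
        x = torch.ones(2, requires_grad=True)
        y = (x + x**2).sum()
        y.backward(retain_graph=True)  # keep the graph alive
        y.backward()  # fine: the graph was retained by the previous call
        z = (x + x**2).sum()
        z.backward()  # frees the graph
        with self.assertRaisesRegex(RuntimeError, "Specify retain_graph=True"):
            z.backward()
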
12987*da0073e9SAndroid Build Coastguard Worker    def test_fork_join_in_middle(self):
12988*da0073e9SAndroid Build Coastguard Worker        # multiple backward with jit threads (fork/join primitive)
12989*da0073e9SAndroid Build Coastguard Worker        # similar to test_python_thread_in_middle, we test with retain_graph=False/True
12990*da0073e9SAndroid Build Coastguard Worker
12991*da0073e9SAndroid Build Coastguard Worker        # Case 1: multiple grad() calls with jit threads, retain_graph=False
12992*da0073e9SAndroid Build Coastguard Worker        # should throw error in some threads with no retain_graph.
12993*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12994*da0073e9SAndroid Build Coastguard Worker        def train_fn_jit_no_retain(middle, orig_x):
12995*da0073e9SAndroid Build Coastguard Worker            y = middle + middle**2
12996*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad([y.sum()], [orig_x])
12997*da0073e9SAndroid Build Coastguard Worker
12998*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
12999*da0073e9SAndroid Build Coastguard Worker        def train_fn_fork_join_calls_no_retain(x):
13000*da0073e9SAndroid Build Coastguard Worker            y_no_retain = (x + 3) * (x + 4) * 0.5
13001*da0073e9SAndroid Build Coastguard Worker
13002*da0073e9SAndroid Build Coastguard Worker            fut = torch.jit._fork(train_fn_jit_no_retain, y_no_retain, x)
13003*da0073e9SAndroid Build Coastguard Worker            grad_hat = train_fn_jit_no_retain(y_no_retain, x)
13004*da0073e9SAndroid Build Coastguard Worker            grad = torch.jit._wait(fut)
13005*da0073e9SAndroid Build Coastguard Worker            return grad, grad_hat
13006*da0073e9SAndroid Build Coastguard Worker
13007*da0073e9SAndroid Build Coastguard Worker        try:
13008*da0073e9SAndroid Build Coastguard Worker            train_fn_fork_join_calls_no_retain(torch.randn(5, 5, requires_grad=True))
13009*da0073e9SAndroid Build Coastguard Worker        except RuntimeError as error:
13010*da0073e9SAndroid Build Coastguard Worker            self.assertRegex(str(error), "Specify retain_graph=True")
13011*da0073e9SAndroid Build Coastguard Worker
13012*da0073e9SAndroid Build Coastguard Worker        # Case 2: no error with retain_graph=True
13013*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13014*da0073e9SAndroid Build Coastguard Worker        def train_fn_jit_retain(middle, orig_x):
13015*da0073e9SAndroid Build Coastguard Worker            y = middle + middle**2
13016*da0073e9SAndroid Build Coastguard Worker            return torch.autograd.grad([y.sum()], [orig_x], retain_graph=True)
13017*da0073e9SAndroid Build Coastguard Worker
13018*da0073e9SAndroid Build Coastguard Worker        @torch.jit.script
13019*da0073e9SAndroid Build Coastguard Worker        def train_fn_fork_join_calls_retain(x):
13020*da0073e9SAndroid Build Coastguard Worker            y_retain = (x + 3) * (x + 4) * 0.5
13021*da0073e9SAndroid Build Coastguard Worker            fut1 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
13022*da0073e9SAndroid Build Coastguard Worker            fut2 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
13023*da0073e9SAndroid Build Coastguard Worker            grad = train_fn_jit_retain(y_retain, x)
13024*da0073e9SAndroid Build Coastguard Worker            grad1 = torch.jit._wait(fut1)
13025*da0073e9SAndroid Build Coastguard Worker            grad2 = torch.jit._wait(fut2)
13026*da0073e9SAndroid Build Coastguard Worker            return grad, grad1, grad2
13027*da0073e9SAndroid Build Coastguard Worker
13028*da0073e9SAndroid Build Coastguard Worker        grad, grad1, grad2 = train_fn_fork_join_calls_retain(
13029*da0073e9SAndroid Build Coastguard Worker            torch.randn(5, 5, requires_grad=True)
13030*da0073e9SAndroid Build Coastguard Worker        )
13031*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad, grad1)
13032*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(grad, grad2)
13033*da0073e9SAndroid Build Coastguard Worker
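    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the public torch.jit.fork / torch.jit.wait pair underlying the
    # _fork/_wait calls above (fork returns a Future; wait blocks for its
    # result), used to run two gradient computations as JIT tasks.
    def _sketch_jit_fork_join_grad(self):
        def task(x):
            y = (x + 3) * (x + 4) * 0.5
            return torch.autograd.grad([y.sum()], [x])

        x = torch.randn(5, 5, requires_grad=True)
        fut = torch.jit.fork(task, x)
        grad_main = task(x)
        grad_forked = torch.jit.wait(fut)
        self.assertEqual(grad_main, grad_forked)
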
13034*da0073e9SAndroid Build Coastguard Worker    def test_preserve_backtrace(self):
13035*da0073e9SAndroid Build Coastguard Worker        class Foo(torch.autograd.Function):
13036*da0073e9SAndroid Build Coastguard Worker            @staticmethod
13037*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, input):
13038*da0073e9SAndroid Build Coastguard Worker                return input
13039*da0073e9SAndroid Build Coastguard Worker
13040*da0073e9SAndroid Build Coastguard Worker            @staticmethod
13041*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, *grad):
13042*da0073e9SAndroid Build Coastguard Worker                raise ValueError("something")
13043*da0073e9SAndroid Build Coastguard Worker
13044*da0073e9SAndroid Build Coastguard Worker        t = torch.rand(10, requires_grad=True)
13045*da0073e9SAndroid Build Coastguard Worker        try:
13046*da0073e9SAndroid Build Coastguard Worker            Foo.apply(t).sum().backward()
13047*da0073e9SAndroid Build Coastguard Worker        except Exception:
13048*da0073e9SAndroid Build Coastguard Worker            import traceback
13049*da0073e9SAndroid Build Coastguard Worker
13050*da0073e9SAndroid Build Coastguard Worker            tb = sys.exc_info()[2]
13051*da0073e9SAndroid Build Coastguard Worker            tb_str = "\n".join(traceback.format_tb(tb))
13052*da0073e9SAndroid Build Coastguard Worker            self.assertTrue('raise ValueError("something")' in tb_str)
13053*da0073e9SAndroid Build Coastguard Worker
13054*da0073e9SAndroid Build Coastguard Worker    # TODO(@anjali411): add an OpInfo based test for torch.cat
13055*da0073e9SAndroid Build Coastguard Worker    # Issue: https://github.com/pytorch/pytorch/issues/51627
13056*da0073e9SAndroid Build Coastguard Worker    #        https://github.com/pytorch/pytorch/issues/75852
13057*da0073e9SAndroid Build Coastguard Worker    def test_cat_stack_r_to_c(self):
13058*da0073e9SAndroid Build Coastguard Worker        inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
13059*da0073e9SAndroid Build Coastguard Worker        inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
13060*da0073e9SAndroid Build Coastguard Worker
13061*da0073e9SAndroid Build Coastguard Worker        def fn(x1, x2):
13062*da0073e9SAndroid Build Coastguard Worker            return torch.cat((x1, x2), dim=-1)
13063*da0073e9SAndroid Build Coastguard Worker
13064*da0073e9SAndroid Build Coastguard Worker        def fn2(x1, x2):
13065*da0073e9SAndroid Build Coastguard Worker            return torch.stack((x1, x2), dim=-1)
13066*da0073e9SAndroid Build Coastguard Worker
13067*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
13068*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)
13069*da0073e9SAndroid Build Coastguard Worker
13070*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)
13071*da0073e9SAndroid Build Coastguard Worker        torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True)
13072*da0073e9SAndroid Build Coastguard Worker
13073*da0073e9SAndroid Build Coastguard Worker    def test_set_multithreading_enabled_as_context_manager_and_function(self):
13074*da0073e9SAndroid Build Coastguard Worker        # Test as a context manager
13075*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
13076*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(torch.autograd.is_multithreading_enabled())
13077*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.autograd.is_multithreading_enabled())
13078*da0073e9SAndroid Build Coastguard Worker
13079*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(True):
13080*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd.is_multithreading_enabled())
13081*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.autograd.is_multithreading_enabled())
13082*da0073e9SAndroid Build Coastguard Worker
13083*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
13084*da0073e9SAndroid Build Coastguard Worker            torch.autograd.set_multithreading_enabled(True)
13085*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.autograd.is_multithreading_enabled())
13086*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.autograd.is_multithreading_enabled())
13087*da0073e9SAndroid Build Coastguard Worker
13088*da0073e9SAndroid Build Coastguard Worker        torch.autograd.set_multithreading_enabled(False)
13089*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(torch.autograd.is_multithreading_enabled())
13090*da0073e9SAndroid Build Coastguard Worker
13091*da0073e9SAndroid Build Coastguard Worker        torch.autograd.set_multithreading_enabled(True)
13092*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(torch.autograd.is_multithreading_enabled())
13093*da0073e9SAndroid Build Coastguard Worker
13094*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
13095*da0073e9SAndroid Build Coastguard Worker    def test_custom_function_propagates_errors_from_device_thread(self):
13096*da0073e9SAndroid Build Coastguard Worker        class MyFunc(Function):
13097*da0073e9SAndroid Build Coastguard Worker            @staticmethod
13098*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
13099*da0073e9SAndroid Build Coastguard Worker                return x
13100*da0073e9SAndroid Build Coastguard Worker
13101*da0073e9SAndroid Build Coastguard Worker            @staticmethod
13102*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
13103*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("blah")
13104*da0073e9SAndroid Build Coastguard Worker                return gO
13105*da0073e9SAndroid Build Coastguard Worker
13106*da0073e9SAndroid Build Coastguard Worker        t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda"))
13107*da0073e9SAndroid Build Coastguard Worker        out = MyFunc.apply(t).sum()
13108*da0073e9SAndroid Build Coastguard Worker
13109*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "blah"):
13110*da0073e9SAndroid Build Coastguard Worker            out.backward()
13111*da0073e9SAndroid Build Coastguard Worker
13112*da0073e9SAndroid Build Coastguard Worker
13113*da0073e9SAndroid Build Coastguard Workerclass TestNestedCheckpoint(TestCase):
13114*da0073e9SAndroid Build Coastguard Worker    @staticmethod
13115*da0073e9SAndroid Build Coastguard Worker    def grad(fn):
13116*da0073e9SAndroid Build Coastguard Worker        def wrapper(x):
13117*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
13118*da0073e9SAndroid Build Coastguard Worker                out = fn(x)
13119*da0073e9SAndroid Build Coastguard Worker                (grad_input,) = torch.autograd.grad(out, inputs=(x,), create_graph=True)
13120*da0073e9SAndroid Build Coastguard Worker            return grad_input
13121*da0073e9SAndroid Build Coastguard Worker
13122*da0073e9SAndroid Build Coastguard Worker        return wrapper
13123*da0073e9SAndroid Build Coastguard Worker
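    # Illustrative sketch (added for exposition; not part of the original test
    # suite): what create_graph=True in the grad() wrapper above enables --
    # the first-order gradient is itself differentiable, so wrapping twice
    # yields second derivatives.
    def _sketch_higher_order_grad(self):
        x = torch.tensor(2.0, requires_grad=True)
        y = x**3
        (g1,) = torch.autograd.grad(y, x, create_graph=True)  # 3 * x**2 = 12
        (g2,) = torch.autograd.grad(g1, x)  # 6 * x = 12
        self.assertEqual(g1, torch.tensor(12.0))
        self.assertEqual(g2, torch.tensor(12.0))
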
13124*da0073e9SAndroid Build Coastguard Worker    @staticmethod
13125*da0073e9SAndroid Build Coastguard Worker    def sum(fn):
13126*da0073e9SAndroid Build Coastguard Worker        def wrapped(x):
13127*da0073e9SAndroid Build Coastguard Worker            return fn(x).sum()
13128*da0073e9SAndroid Build Coastguard Worker
13129*da0073e9SAndroid Build Coastguard Worker        return wrapped
13130*da0073e9SAndroid Build Coastguard Worker
13131*da0073e9SAndroid Build Coastguard Worker    @staticmethod
13132*da0073e9SAndroid Build Coastguard Worker    def checkpoint(fn):
13133*da0073e9SAndroid Build Coastguard Worker        def wrapped(*args, **kwargs):
13134*da0073e9SAndroid Build Coastguard Worker            return torch.utils.checkpoint.checkpoint(
13135*da0073e9SAndroid Build Coastguard Worker                fn, *args, use_reentrant=False, **kwargs
13136*da0073e9SAndroid Build Coastguard Worker            )
13137*da0073e9SAndroid Build Coastguard Worker
13138*da0073e9SAndroid Build Coastguard Worker        return wrapped
13139*da0073e9SAndroid Build Coastguard Worker
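    # Illustrative sketch (added for exposition; not part of the original test
    # suite): the basic contract the checkpoint() wrapper above relies on --
    # non-reentrant checkpointing recomputes activations during backward and
    # produces the same gradients as the uncheckpointed call.
    def _sketch_checkpoint_matches_uncheckpointed(self):
        def fn(x):
            return x.sin().exp().sin()

        x = torch.randn(3, requires_grad=True)
        (expected,) = torch.autograd.grad(fn(x).sum(), x)
        out = torch.utils.checkpoint.checkpoint(fn, x, use_reentrant=False)
        (actual,) = torch.autograd.grad(out.sum(), x)
        self.assertEqual(expected, actual)
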
13140*da0073e9SAndroid Build Coastguard Worker    def get_tests(self, fn):
13141*da0073e9SAndroid Build Coastguard Worker        grad, c = self.grad, self.checkpoint
13142*da0073e9SAndroid Build Coastguard Worker
13143*da0073e9SAndroid Build Coastguard Worker        tests = (
13144*da0073e9SAndroid Build Coastguard Worker            # pairs of (reference function, tuple of the same function wrapped in checkpoint in various ways)
13145*da0073e9SAndroid Build Coastguard Worker            (fn, (c(fn), c(c(fn)))),
13146*da0073e9SAndroid Build Coastguard Worker            (grad(fn), (grad(c(fn)), grad(c(c(fn))))),
13147*da0073e9SAndroid Build Coastguard Worker            (
13148*da0073e9SAndroid Build Coastguard Worker                grad(grad(fn)),
13149*da0073e9SAndroid Build Coastguard Worker                (grad(c(grad(fn))), c(grad(grad(c(fn)))), grad(c(grad(c(fn))))),
13150*da0073e9SAndroid Build Coastguard Worker            ),
13151*da0073e9SAndroid Build Coastguard Worker            (
13152*da0073e9SAndroid Build Coastguard Worker                grad(grad(grad(fn))),
13153*da0073e9SAndroid Build Coastguard Worker                (grad(c(grad(grad(c(fn))))), grad(c(grad(c(grad(c(fn))))))),
13154*da0073e9SAndroid Build Coastguard Worker            ),
13155*da0073e9SAndroid Build Coastguard Worker        )
13156*da0073e9SAndroid Build Coastguard Worker        return tests
13157*da0073e9SAndroid Build Coastguard Worker
13158*da0073e9SAndroid Build Coastguard Worker    def check_graph_dies(self, fn):
13159*da0073e9SAndroid Build Coastguard Worker        def iter_graph(roots):
13160*da0073e9SAndroid Build Coastguard Worker            if not roots:
13161*da0073e9SAndroid Build Coastguard Worker                return
13162*da0073e9SAndroid Build Coastguard Worker            seen = set()
13163*da0073e9SAndroid Build Coastguard Worker            q = collections.deque()
13164*da0073e9SAndroid Build Coastguard Worker            for node in roots:
13165*da0073e9SAndroid Build Coastguard Worker                if node is not None:
13166*da0073e9SAndroid Build Coastguard Worker                    seen.add(node)
13167*da0073e9SAndroid Build Coastguard Worker                    q.append(node)
13168*da0073e9SAndroid Build Coastguard Worker
13169*da0073e9SAndroid Build Coastguard Worker            while q:
13170*da0073e9SAndroid Build Coastguard Worker                node = q.popleft()
13171*da0073e9SAndroid Build Coastguard Worker                for fn, _idx in node.next_functions:
13172*da0073e9SAndroid Build Coastguard Worker                    if fn in seen or fn is None:
13173*da0073e9SAndroid Build Coastguard Worker                        continue
13174*da0073e9SAndroid Build Coastguard Worker                    seen.add(fn)
13175*da0073e9SAndroid Build Coastguard Worker                    q.append(fn)
13176*da0073e9SAndroid Build Coastguard Worker
13177*da0073e9SAndroid Build Coastguard Worker                yield node
13178*da0073e9SAndroid Build Coastguard Worker
13179*da0073e9SAndroid Build Coastguard Worker        class Handle:
13180*da0073e9SAndroid Build Coastguard Worker            __slots__ = ["node_name"]
13181*da0073e9SAndroid Build Coastguard Worker
13182*da0073e9SAndroid Build Coastguard Worker            def __init__(self, node_name):
13183*da0073e9SAndroid Build Coastguard Worker                self.node_name = node_name
13184*da0073e9SAndroid Build Coastguard Worker
13185*da0073e9SAndroid Build Coastguard Worker        def scope():
13186*da0073e9SAndroid Build Coastguard Worker            a = torch.randn((), requires_grad=True)
13187*da0073e9SAndroid Build Coastguard Worker            out = fn(a)
13188*da0073e9SAndroid Build Coastguard Worker            refs = []
13189*da0073e9SAndroid Build Coastguard Worker            for node in iter_graph([out.grad_fn]):
13190*da0073e9SAndroid Build Coastguard Worker                handle = Handle(node.name())
13191*da0073e9SAndroid Build Coastguard Worker                refs.append(weakref.ref(handle))
13192*da0073e9SAndroid Build Coastguard Worker                node.metadata["blah"] = handle
13193*da0073e9SAndroid Build Coastguard Worker            return refs
13194*da0073e9SAndroid Build Coastguard Worker
13195*da0073e9SAndroid Build Coastguard Worker        refs = scope()
13196*da0073e9SAndroid Build Coastguard Worker        node_names = [ref().node_name for ref in refs if ref() is not None]
13197*da0073e9SAndroid Build Coastguard Worker        if len(node_names) > 0:
13198*da0073e9SAndroid Build Coastguard Worker            print("Nodes still alive:", node_names)
13199*da0073e9SAndroid Build Coastguard Worker
13200*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(node_names), 0)
13201*da0073e9SAndroid Build Coastguard Worker
13202*da0073e9SAndroid Build Coastguard Worker    @parametrize("early_stop", [True, False])
13203*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint(self, early_stop):
13204*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13205*da0073e9SAndroid Build Coastguard Worker            x = torch.randn((), requires_grad=True)
13206*da0073e9SAndroid Build Coastguard Worker
13207*da0073e9SAndroid Build Coastguard Worker            def f(x):
13208*da0073e9SAndroid Build Coastguard Worker                out = x.sin().exp().sin()
13209*da0073e9SAndroid Build Coastguard Worker                return out
13210*da0073e9SAndroid Build Coastguard Worker
13211*da0073e9SAndroid Build Coastguard Worker            def g(x):
13212*da0073e9SAndroid Build Coastguard Worker                a = x.sin().exp().sin()
13213*da0073e9SAndroid Build Coastguard Worker                b = x.sin().exp().sin()
13214*da0073e9SAndroid Build Coastguard Worker                (ga,) = torch.autograd.grad(a, x)
13215*da0073e9SAndroid Build Coastguard Worker                (gb,) = torch.autograd.grad(b, x)
13216*da0073e9SAndroid Build Coastguard Worker                return x.sin()
13217*da0073e9SAndroid Build Coastguard Worker
13218*da0073e9SAndroid Build Coastguard Worker            for fn in (f, g):
13219*da0073e9SAndroid Build Coastguard Worker                for expected_fn, actual_fns in self.get_tests(fn):
13220*da0073e9SAndroid Build Coastguard Worker                    expected = expected_fn(x)
13221*da0073e9SAndroid Build Coastguard Worker
13222*da0073e9SAndroid Build Coastguard Worker                    for actual_fn in actual_fns:
13223*da0073e9SAndroid Build Coastguard Worker                        actual = actual_fn(x)
13224*da0073e9SAndroid Build Coastguard Worker                        self.assertTrue(torch.allclose(expected, actual))
13225*da0073e9SAndroid Build Coastguard Worker                        self.check_graph_dies(actual_fn)
13226*da0073e9SAndroid Build Coastguard Worker
13227*da0073e9SAndroid Build Coastguard Worker    @parametrize("early_stop", [True, False])
13228*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_two_children(self, early_stop):
13229*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13230*da0073e9SAndroid Build Coastguard Worker            grad, sum, c = self.grad, self.sum, self.checkpoint
13231*da0073e9SAndroid Build Coastguard Worker
13232*da0073e9SAndroid Build Coastguard Worker            def f(x):
13233*da0073e9SAndroid Build Coastguard Worker                return x.sin().exp().sin()
13234*da0073e9SAndroid Build Coastguard Worker
13235*da0073e9SAndroid Build Coastguard Worker            def g(x):
13236*da0073e9SAndroid Build Coastguard Worker                return x.cos().sin().exp()
13237*da0073e9SAndroid Build Coastguard Worker
13238*da0073e9SAndroid Build Coastguard Worker            def hc(x):
13239*da0073e9SAndroid Build Coastguard Worker                return c(g)(c(f)(x))
13240*da0073e9SAndroid Build Coastguard Worker
13241*da0073e9SAndroid Build Coastguard Worker            def h(x):
13242*da0073e9SAndroid Build Coastguard Worker                return g(f(x))
13243*da0073e9SAndroid Build Coastguard Worker
13244*da0073e9SAndroid Build Coastguard Worker            a = torch.randn(3, 3, requires_grad=True)
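            # Second-order gradient check: hc (which checkpoints f and g
            # separately) wrapped in another checkpoint should match the plain
            # composition h = g(f(x)).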
13245*da0073e9SAndroid Build Coastguard Worker            expected = grad(sum(grad(sum(h))))(a)
13246*da0073e9SAndroid Build Coastguard Worker            actual = grad(sum(grad(sum(c(hc)))))(a)
13247*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(expected, actual))
13248*da0073e9SAndroid Build Coastguard Worker
13249*da0073e9SAndroid Build Coastguard Worker            actual = grad(sum(c(grad(sum(c(hc))))))(a)
13250*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(expected, actual))
13251*da0073e9SAndroid Build Coastguard Worker
13252*da0073e9SAndroid Build Coastguard Worker            self.check_graph_dies(grad(c(hc)))
13253*da0073e9SAndroid Build Coastguard Worker            self.check_graph_dies(grad(sum(grad(sum(c(hc))))))
13254*da0073e9SAndroid Build Coastguard Worker            self.check_graph_dies(grad(sum(c(grad(sum(c(hc)))))))
13255*da0073e9SAndroid Build Coastguard Worker
13256*da0073e9SAndroid Build Coastguard Worker    @parametrize("early_stop", [True, False])
13257*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_non_tensor_inputs_and_outputs(self, early_stop):
13258*da0073e9SAndroid Build Coastguard Worker        def fn(k, a, b, f):
13259*da0073e9SAndroid Build Coastguard Worker            return f(k * a * b.exp()), 1, "abcd"
13260*da0073e9SAndroid Build Coastguard Worker
13261*da0073e9SAndroid Build Coastguard Worker        k = 3
13262*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(2.0, requires_grad=True)
13263*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(3.0, requires_grad=True)
13264*da0073e9SAndroid Build Coastguard Worker
13265*da0073e9SAndroid Build Coastguard Worker        def f(x):
13266*da0073e9SAndroid Build Coastguard Worker            return x.sin()
13267*da0073e9SAndroid Build Coastguard Worker
13268*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13269*da0073e9SAndroid Build Coastguard Worker            out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False)
13270*da0073e9SAndroid Build Coastguard Worker        actual_grads = torch.autograd.grad(out, (a, b))
13271*da0073e9SAndroid Build Coastguard Worker
13272*da0073e9SAndroid Build Coastguard Worker        out, _unused1, _unused2 = fn(k, a, b, f)
13273*da0073e9SAndroid Build Coastguard Worker        expected_grads = torch.autograd.grad(out, (a, b))
13274*da0073e9SAndroid Build Coastguard Worker        for actual, expected in zip(actual_grads, expected_grads):
13275*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(torch.allclose(actual, expected))
13276*da0073e9SAndroid Build Coastguard Worker
13277*da0073e9SAndroid Build Coastguard Worker    @parametrize("early_stop", [True, False])
13278*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_kwargs(self, early_stop):
13279*da0073e9SAndroid Build Coastguard Worker        def fn(a, blah=None):
13280*da0073e9SAndroid Build Coastguard Worker            out = a.sin().exp()
13281*da0073e9SAndroid Build Coastguard Worker            if blah is not None:
13282*da0073e9SAndroid Build Coastguard Worker                out = out * blah
13283*da0073e9SAndroid Build Coastguard Worker            return out.sin().exp()
13284*da0073e9SAndroid Build Coastguard Worker
13285*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(2.0, requires_grad=True)
13286*da0073e9SAndroid Build Coastguard Worker        b = torch.tensor(3.0, requires_grad=True)
13287*da0073e9SAndroid Build Coastguard Worker
13288*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13289*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, blah=b, use_reentrant=False)
13290*da0073e9SAndroid Build Coastguard Worker            actual_grads = torch.autograd.grad(out, (a, b))
13291*da0073e9SAndroid Build Coastguard Worker
13292*da0073e9SAndroid Build Coastguard Worker            out = fn(a, blah=b)
13293*da0073e9SAndroid Build Coastguard Worker            expected_grads = torch.autograd.grad(out, (a, b))
13294*da0073e9SAndroid Build Coastguard Worker            for actual, expected in zip(actual_grads, expected_grads):
13295*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(torch.allclose(actual, expected))
13296*da0073e9SAndroid Build Coastguard Worker
13297*da0073e9SAndroid Build Coastguard Worker    @parametrize("early_stop", [True, False])
13298*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_same_graph(self, early_stop):
13299*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13300*da0073e9SAndroid Build Coastguard Worker
13301*da0073e9SAndroid Build Coastguard Worker        def hook(*_unused_args):
13302*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
13303*da0073e9SAndroid Build Coastguard Worker
13304*da0073e9SAndroid Build Coastguard Worker        def fn(a):
13305*da0073e9SAndroid Build Coastguard Worker            return a.sin().cos().sin()
13306*da0073e9SAndroid Build Coastguard Worker
13307*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13308*da0073e9SAndroid Build Coastguard Worker
13309*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13310*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=False)
13311*da0073e9SAndroid Build Coastguard Worker        # The hook is registered on the original graph
13312*da0073e9SAndroid Build Coastguard Worker        out.grad_fn.next_functions[0][0].register_hook(hook)
13313*da0073e9SAndroid Build Coastguard Worker        # And backward is performed on the original graph
13314*da0073e9SAndroid Build Coastguard Worker        out.backward()
13315*da0073e9SAndroid Build Coastguard Worker
13316*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
13317*da0073e9SAndroid Build Coastguard Worker
13318*da0073e9SAndroid Build Coastguard Worker    @parametrize("early_stop", [True, False])
13319*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_reentrant_backwards(self, early_stop):
13320*da0073e9SAndroid Build Coastguard Worker        def fn(a):
13321*da0073e9SAndroid Build Coastguard Worker            x = a.sin().cos()
13322*da0073e9SAndroid Build Coastguard Worker            out = x.sin()
13323*da0073e9SAndroid Build Coastguard Worker            return x, out
13324*da0073e9SAndroid Build Coastguard Worker
13325*da0073e9SAndroid Build Coastguard Worker        def hook(*_unused_args):
13326*da0073e9SAndroid Build Coastguard Worker            # do backward again, but skip over the part of the graph where
13327*da0073e9SAndroid Build Coastguard Worker            # the hook was registered
13328*da0073e9SAndroid Build Coastguard Worker            x.backward(retain_graph=True)
13329*da0073e9SAndroid Build Coastguard Worker
13330*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13331*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
13332*da0073e9SAndroid Build Coastguard Worker            x, out = checkpoint(fn, a, use_reentrant=False)
13333*da0073e9SAndroid Build Coastguard Worker        out.grad_fn.register_hook(hook)
13334*da0073e9SAndroid Build Coastguard Worker        out.backward(retain_graph=True)
13335*da0073e9SAndroid Build Coastguard Worker
13336*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_set_early_stop(self):
13337*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13338*da0073e9SAndroid Build Coastguard Worker
13339*da0073e9SAndroid Build Coastguard Worker        def clone(x):
13340*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
13341*da0073e9SAndroid Build Coastguard Worker            return x.clone()
13342*da0073e9SAndroid Build Coastguard Worker
13343*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13344*da0073e9SAndroid Build Coastguard Worker            # Since clone does not save anything, it is not recomputed iff
13345*da0073e9SAndroid Build Coastguard Worker            # early stop is enabled.
13346*da0073e9SAndroid Build Coastguard Worker            return clone(x.sin().cos())
13347*da0073e9SAndroid Build Coastguard Worker
13348*da0073e9SAndroid Build Coastguard Worker        # Early stopping is enabled by default
13349*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13350*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, a, use_reentrant=False)
13351*da0073e9SAndroid Build Coastguard Worker        out.backward()
13352*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
13353*da0073e9SAndroid Build Coastguard Worker
13354*da0073e9SAndroid Build Coastguard Worker        # Try using the context manager to set early stopping to False.
13355*da0073e9SAndroid Build Coastguard Worker        # Expect early stopping to be disabled for all checkpoints run under
13356*da0073e9SAndroid Build Coastguard Worker        # the context manager, even though the context manager is no longer active
13357*da0073e9SAndroid Build Coastguard Worker        # when backward/recomputation is performed.
13358*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13359*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13360*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
13361*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=False)
13362*da0073e9SAndroid Build Coastguard Worker
13363*da0073e9SAndroid Build Coastguard Worker        out.backward()
13364*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 2)
13365*da0073e9SAndroid Build Coastguard Worker
13366*da0073e9SAndroid Build Coastguard Worker    def test_nested_checkpoint_set_early_stop_no_recomputation_needed(self):
13367*da0073e9SAndroid Build Coastguard Worker        # Case 1: We have one tensor saved and it's the input
13368*da0073e9SAndroid Build Coastguard Worker
13369*da0073e9SAndroid Build Coastguard Worker        # We have two different counters here because in this case we actually
13370*da0073e9SAndroid Build Coastguard Worker        # do call into x.sin() at the python level during recomputation whether
13371*da0073e9SAndroid Build Coastguard Worker        # or not early stop is enabled. This is because the early stopping
13372*da0073e9SAndroid Build Coastguard Worker        # only happens at the autograd level (preventing us from reaching the
13373*da0073e9SAndroid Build Coastguard Worker        # backend).
13374*da0073e9SAndroid Build Coastguard Worker        python_dispatch_counter = [0]
13375*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13376*da0073e9SAndroid Build Coastguard Worker
13377*da0073e9SAndroid Build Coastguard Worker        class SinCounterMode(TorchDispatchMode):
13378*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
13379*da0073e9SAndroid Build Coastguard Worker                self.count = 0
13380*da0073e9SAndroid Build Coastguard Worker
13381*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
13382*da0073e9SAndroid Build Coastguard Worker                kwargs = {} if kwargs is None else kwargs
13383*da0073e9SAndroid Build Coastguard Worker                if func is torch.ops.aten.sin.default:
13384*da0073e9SAndroid Build Coastguard Worker                    self.count += 1
13385*da0073e9SAndroid Build Coastguard Worker                return func(*args, **kwargs)
13386*da0073e9SAndroid Build Coastguard Worker
13387*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13388*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
13389*da0073e9SAndroid Build Coastguard Worker            return x.sin()
13390*da0073e9SAndroid Build Coastguard Worker
13391*da0073e9SAndroid Build Coastguard Worker        # With early stopping (enabled by default)
13392*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13393*da0073e9SAndroid Build Coastguard Worker        with SinCounterMode() as python_dispatch_counter:  # noqa: F811
13394*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, a, use_reentrant=False)
13395*da0073e9SAndroid Build Coastguard Worker            out.backward()
13396*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 2)
13397*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(python_dispatch_counter.count, 1)
13398*da0073e9SAndroid Build Coastguard Worker
13399*da0073e9SAndroid Build Coastguard Worker        # Without early stopping
13400*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13401*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13402*da0073e9SAndroid Build Coastguard Worker        with SinCounterMode() as python_dispatch_counter:
13403*da0073e9SAndroid Build Coastguard Worker            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
13404*da0073e9SAndroid Build Coastguard Worker                out = checkpoint(fn, a, use_reentrant=False)
13405*da0073e9SAndroid Build Coastguard Worker            out.backward()
13406*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 2)
13407*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(python_dispatch_counter.count, 2)
13408*da0073e9SAndroid Build Coastguard Worker
13409*da0073e9SAndroid Build Coastguard Worker        # Case 2: Forward saves no tensors
13410*da0073e9SAndroid Build Coastguard Worker
13411*da0073e9SAndroid Build Coastguard Worker        # Since unpack isn't even called, counter is 1 whether or not early stop
13412*da0073e9SAndroid Build Coastguard Worker        # is enabled!
13413*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13414*da0073e9SAndroid Build Coastguard Worker
13415*da0073e9SAndroid Build Coastguard Worker        def fn2(x):
13416*da0073e9SAndroid Build Coastguard Worker            counter[0] += 1
13417*da0073e9SAndroid Build Coastguard Worker            return x.clone()
13418*da0073e9SAndroid Build Coastguard Worker
13419*da0073e9SAndroid Build Coastguard Worker        # With early stopping (enabled by default)
13420*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13421*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn2, a, use_reentrant=False)
13422*da0073e9SAndroid Build Coastguard Worker        out.backward()
13423*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
13424*da0073e9SAndroid Build Coastguard Worker
13425*da0073e9SAndroid Build Coastguard Worker        # Without early stopping
13426*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13427*da0073e9SAndroid Build Coastguard Worker        a = torch.tensor(1.0, requires_grad=True)
13428*da0073e9SAndroid Build Coastguard Worker        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
13429*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn2, a, use_reentrant=False)
13430*da0073e9SAndroid Build Coastguard Worker        out.backward()
13431*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 1)
13432*da0073e9SAndroid Build Coastguard Worker
13433*da0073e9SAndroid Build Coastguard Worker
13434*da0073e9SAndroid Build Coastguard Workerclass TestSelectiveActivationCheckpoint(TestCase):
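    # Selective activation checkpoint (SAC): under non-reentrant checkpoint, a
    # context_fn built from create_selective_checkpoint_contexts decides per-op
    # whether outputs are saved for backward (MUST_SAVE) or recomputed
    # (PREFER_RECOMPUTE).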
13435*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
13436*da0073e9SAndroid Build Coastguard Worker    def test_flops_and_mem(self):
13437*da0073e9SAndroid Build Coastguard Worker        # From https://github.com/pytorch/pytorch/pull/126320
13438*da0073e9SAndroid Build Coastguard Worker        def get_act_mem(f):
13439*da0073e9SAndroid Build Coastguard Worker            out = f()
13440*da0073e9SAndroid Build Coastguard Worker            out.backward()
13441*da0073e9SAndroid Build Coastguard Worker            # Why do one forward and backward first? To warm things up (e.g. allocate .grad buffers) so the delta below only reflects saved activations.
13442*da0073e9SAndroid Build Coastguard Worker            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
13443*da0073e9SAndroid Build Coastguard Worker            out = f()
13444*da0073e9SAndroid Build Coastguard Worker            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
13445*da0073e9SAndroid Build Coastguard Worker            act_mem = (cur_mem - start_mem) / (1024 * 1024)
13446*da0073e9SAndroid Build Coastguard Worker            out.backward()
13447*da0073e9SAndroid Build Coastguard Worker            return act_mem
13448*da0073e9SAndroid Build Coastguard Worker
13449*da0073e9SAndroid Build Coastguard Worker        def get_bw_flops(f):
13450*da0073e9SAndroid Build Coastguard Worker            # Normalized so that a 512 square matmul returns 1
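            # (a 512x512x512 matmul is 2 * 512**3 FLOPs, hence the divisor in the return)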
13451*da0073e9SAndroid Build Coastguard Worker            f().backward()
13452*da0073e9SAndroid Build Coastguard Worker            out = f()
13453*da0073e9SAndroid Build Coastguard Worker            # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so
13454*da0073e9SAndroid Build Coastguard Worker            # it will be able to observe whether an op is cached or not.
13455*da0073e9SAndroid Build Coastguard Worker            with FlopCounterMode(display=False) as mode:
13456*da0073e9SAndroid Build Coastguard Worker                out.backward()
13457*da0073e9SAndroid Build Coastguard Worker            return mode.get_total_flops() / (512**3 * 2)
13458*da0073e9SAndroid Build Coastguard Worker
13459*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(512, 512, requires_grad=True, device="cuda")
13460*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(512, 512, requires_grad=True, device="cuda")
13461*da0073e9SAndroid Build Coastguard Worker
13462*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
13463*da0073e9SAndroid Build Coastguard Worker            return torch.mm(x.cos(), y).sin().sum()
13464*da0073e9SAndroid Build Coastguard Worker
13465*da0073e9SAndroid Build Coastguard Worker        def fn_ac(x, y):
13466*da0073e9SAndroid Build Coastguard Worker            return checkpoint(fn, x, y, use_reentrant=False)
13467*da0073e9SAndroid Build Coastguard Worker
13468*da0073e9SAndroid Build Coastguard Worker        def fn_sac(x, y):
13469*da0073e9SAndroid Build Coastguard Worker            context_fn = functools.partial(
13470*da0073e9SAndroid Build Coastguard Worker                create_selective_checkpoint_contexts,
13471*da0073e9SAndroid Build Coastguard Worker                [torch.ops.aten.mm.default],
13472*da0073e9SAndroid Build Coastguard Worker            )
13473*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13474*da0073e9SAndroid Build Coastguard Worker            return out
13475*da0073e9SAndroid Build Coastguard Worker
13476*da0073e9SAndroid Build Coastguard Worker        def policy_fn(ctx, op, *args, **kwargs):
13477*da0073e9SAndroid Build Coastguard Worker            if op == torch.ops.aten.mm.default:
13478*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.MUST_SAVE
13479*da0073e9SAndroid Build Coastguard Worker            else:
13480*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.PREFER_RECOMPUTE
13481*da0073e9SAndroid Build Coastguard Worker
13482*da0073e9SAndroid Build Coastguard Worker        def fn_sac2(x, y):
13483*da0073e9SAndroid Build Coastguard Worker            context_fn = functools.partial(
13484*da0073e9SAndroid Build Coastguard Worker                create_selective_checkpoint_contexts,
13485*da0073e9SAndroid Build Coastguard Worker                policy_fn,
13486*da0073e9SAndroid Build Coastguard Worker            )
13487*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13488*da0073e9SAndroid Build Coastguard Worker            return out
13489*da0073e9SAndroid Build Coastguard Worker
13490*da0073e9SAndroid Build Coastguard Worker        def policy_fn_bool(ctx, op, *args, **kwargs):
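            # Booleans are also accepted as a policy result: True behaves like
            # MUST_SAVE and False like PREFER_RECOMPUTE, so fn_sac3 below is
            # expected to match fn_sac2.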
13491*da0073e9SAndroid Build Coastguard Worker            return op == torch.ops.aten.mm.default
13492*da0073e9SAndroid Build Coastguard Worker
13493*da0073e9SAndroid Build Coastguard Worker        def fn_sac3(x, y):
13494*da0073e9SAndroid Build Coastguard Worker            context_fn = functools.partial(
13495*da0073e9SAndroid Build Coastguard Worker                create_selective_checkpoint_contexts,
13496*da0073e9SAndroid Build Coastguard Worker                policy_fn_bool,
13497*da0073e9SAndroid Build Coastguard Worker            )
13498*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13499*da0073e9SAndroid Build Coastguard Worker            return out
13500*da0073e9SAndroid Build Coastguard Worker
13501*da0073e9SAndroid Build Coastguard Worker        act_mem_noac = get_act_mem(lambda: fn(x, y))
13502*da0073e9SAndroid Build Coastguard Worker        bw_flops_noac = get_bw_flops(lambda: fn(x, y))
13503*da0073e9SAndroid Build Coastguard Worker
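        # Without checkpointing, x.cos() and the mm output are saved for backward
        # (1 MiB each at fp32), and mm's backward does two 512-square matmuls.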
13504*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(act_mem_noac, 2.0)
13505*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bw_flops_noac, 2.0)
13506*da0073e9SAndroid Build Coastguard Worker
13507*da0073e9SAndroid Build Coastguard Worker        act_mem_ac = get_act_mem(lambda: fn_ac(x, y))
13508*da0073e9SAndroid Build Coastguard Worker        bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y))
13509*da0073e9SAndroid Build Coastguard Worker
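        # Full checkpointing saves no intermediate activations; backward first
        # recomputes the forward (one extra matmul) before the two backward matmuls.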
13510*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(act_mem_ac, 0.0)
13511*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bw_flops_ac, 3.0)
13512*da0073e9SAndroid Build Coastguard Worker
13513*da0073e9SAndroid Build Coastguard Worker        act_mem_sac = get_act_mem(lambda: fn_sac(x, y))
13514*da0073e9SAndroid Build Coastguard Worker        bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y))
13515*da0073e9SAndroid Build Coastguard Worker
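        # SAC saves only the mm output (1 MiB) and recomputes the cheap pointwise
        # ops, so the forward matmul is not repeated in backward.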
13516*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(act_mem_sac, 1.0)
13517*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bw_flops_sac, 2.0)
13518*da0073e9SAndroid Build Coastguard Worker
13519*da0073e9SAndroid Build Coastguard Worker        act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y))
13520*da0073e9SAndroid Build Coastguard Worker        bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y))
13521*da0073e9SAndroid Build Coastguard Worker
13522*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(act_mem_sac2, 1.0)
13523*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bw_flops_sac2, 2.0)
13524*da0073e9SAndroid Build Coastguard Worker
13525*da0073e9SAndroid Build Coastguard Worker        act_mem_sac3 = get_act_mem(lambda: fn_sac3(x, y))
13526*da0073e9SAndroid Build Coastguard Worker        bw_flops_sac3 = get_bw_flops(lambda: fn_sac3(x, y))
13527*da0073e9SAndroid Build Coastguard Worker
13528*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(act_mem_sac3, 1.0)
13529*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(bw_flops_sac3, 2.0)
13530*da0073e9SAndroid Build Coastguard Worker
13531*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13532*da0073e9SAndroid Build Coastguard Worker    def test_output_already_has_autograd_meta(self):
13533*da0073e9SAndroid Build Coastguard Worker        # View of tensor of non-differentiable dtype still has AutogradMeta
13534*da0073e9SAndroid Build Coastguard Worker        def fn(x, y):
13535*da0073e9SAndroid Build Coastguard Worker            return x.view(-1), y.sin().cos()
13536*da0073e9SAndroid Build Coastguard Worker
13537*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([1, 2, 3], dtype=torch.int64)
13538*da0073e9SAndroid Build Coastguard Worker        y = torch.randn(3, requires_grad=True)
13539*da0073e9SAndroid Build Coastguard Worker
13540*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(
13541*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts,
13542*da0073e9SAndroid Build Coastguard Worker            [torch.ops.aten.view.default],
13543*da0073e9SAndroid Build Coastguard Worker        )
13544*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
13545*da0073e9SAndroid Build Coastguard Worker        out[1].sum().backward()
13546*da0073e9SAndroid Build Coastguard Worker
13547*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13548*da0073e9SAndroid Build Coastguard Worker    def test_subclass_dispatching_sizes(self):
13549*da0073e9SAndroid Build Coastguard Worker        # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default
13550*da0073e9SAndroid Build Coastguard Worker        # Caching such metadata ops can be problematic when the following are satisfied:
13551*da0073e9SAndroid Build Coastguard Worker        #
13552*da0073e9SAndroid Build Coastguard Worker        # 1. size/strides are dispatched upon
13553*da0073e9SAndroid Build Coastguard Worker        # 2. our policy saves sizes
13554*da0073e9SAndroid Build Coastguard Worker        ta = torch.randn(6, 2)
13555*da0073e9SAndroid Build Coastguard Worker
13556*da0073e9SAndroid Build Coastguard Worker        class CustomSizeDynamicShapesTensor(torch.Tensor):
13557*da0073e9SAndroid Build Coastguard Worker            @staticmethod
13558*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, inner):
13559*da0073e9SAndroid Build Coastguard Worker                return torch.Tensor._make_wrapper_subclass(
13560*da0073e9SAndroid Build Coastguard Worker                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
13561*da0073e9SAndroid Build Coastguard Worker                    # Calling the overload that has kwargs causes us to go down the first overload path,
13562*da0073e9SAndroid Build Coastguard Worker                    # which will **always** specialize sizes.
13563*da0073e9SAndroid Build Coastguard Worker                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
13564*da0073e9SAndroid Build Coastguard Worker                    cls,
13565*da0073e9SAndroid Build Coastguard Worker                    inner.size(),
13566*da0073e9SAndroid Build Coastguard Worker                    inner.stride(),
13567*da0073e9SAndroid Build Coastguard Worker                    None,
13568*da0073e9SAndroid Build Coastguard Worker                    None,
13569*da0073e9SAndroid Build Coastguard Worker                    inner.dtype,
13570*da0073e9SAndroid Build Coastguard Worker                    inner.layout,
13571*da0073e9SAndroid Build Coastguard Worker                    inner.device,
13572*da0073e9SAndroid Build Coastguard Worker                    False,
13573*da0073e9SAndroid Build Coastguard Worker                    inner.requires_grad,
13574*da0073e9SAndroid Build Coastguard Worker                    "sizes",
13575*da0073e9SAndroid Build Coastguard Worker                )
13576*da0073e9SAndroid Build Coastguard Worker
13577*da0073e9SAndroid Build Coastguard Worker            def __init__(self, inner):
13578*da0073e9SAndroid Build Coastguard Worker                self.inner = inner
13579*da0073e9SAndroid Build Coastguard Worker
13580*da0073e9SAndroid Build Coastguard Worker            @classmethod
13581*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(cls, func, types, args, kwargs):
13582*da0073e9SAndroid Build Coastguard Worker                if kwargs is None:
13583*da0073e9SAndroid Build Coastguard Worker                    kwargs = {}
13584*da0073e9SAndroid Build Coastguard Worker                args_inner = torch.utils._pytree.tree_map_only(
13585*da0073e9SAndroid Build Coastguard Worker                    cls, lambda x: x.inner, args
13586*da0073e9SAndroid Build Coastguard Worker                )
13587*da0073e9SAndroid Build Coastguard Worker                out_inner = func(*args_inner, **kwargs)
13588*da0073e9SAndroid Build Coastguard Worker                return torch.utils._pytree.tree_map_only(
13589*da0073e9SAndroid Build Coastguard Worker                    torch.Tensor, lambda x: cls(x), out_inner
13590*da0073e9SAndroid Build Coastguard Worker                )
13591*da0073e9SAndroid Build Coastguard Worker
13592*da0073e9SAndroid Build Coastguard Worker        def policy_fn(ctx, op, *args, **kwargs):
13593*da0073e9SAndroid Build Coastguard Worker            if op is torch.ops.aten.sym_size.default:
13594*da0073e9SAndroid Build Coastguard Worker                # Silently ignored!
13595*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.MUST_SAVE
13596*da0073e9SAndroid Build Coastguard Worker            else:
13597*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.PREFER_RECOMPUTE
13598*da0073e9SAndroid Build Coastguard Worker
13599*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13600*da0073e9SAndroid Build Coastguard Worker            # We avoid the following case
13601*da0073e9SAndroid Build Coastguard Worker            #
13602*da0073e9SAndroid Build Coastguard Worker            # saved     :[4, 3], [], [], [4, 3], [4, 3], [4, 3], [12]
13603*da0073e9SAndroid Build Coastguard Worker            # forward   :sum   ,sum,mul, mul   , mul   ,view   , view
13604*da0073e9SAndroid Build Coastguard Worker            # recompute :sum   ,sum,mul, view  , view
13605*da0073e9SAndroid Build Coastguard Worker            #
13606*da0073e9SAndroid Build Coastguard Worker            # Views save the shape of their input, so we expect the second
13607*da0073e9SAndroid Build Coastguard Worker            # view to save 12, but because AC packing during the forward
13608*da0073e9SAndroid Build Coastguard Worker            # saves the shapes of the inputs for metadata checks later,
13609*da0073e9SAndroid Build Coastguard Worker            # we would save the wrong shape during the recompute.
13610*da0073e9SAndroid Build Coastguard Worker            view_out = (x * x.sum()).view(-1).view(4, 3)
13611*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12])
13612*da0073e9SAndroid Build Coastguard Worker            return view_out.exp()
13613*da0073e9SAndroid Build Coastguard Worker
13614*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(4, 3, requires_grad=True)
13615*da0073e9SAndroid Build Coastguard Worker        x_wrapper = CustomSizeDynamicShapesTensor(x)
13616*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13617*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn)
13618*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
13619*da0073e9SAndroid Build Coastguard Worker
13620*da0073e9SAndroid Build Coastguard Worker    def test_bad_inputs(self):
13621*da0073e9SAndroid Build Coastguard Worker        bad_op_list1 = [2]
13622*da0073e9SAndroid Build Coastguard Worker
13623*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
13624*da0073e9SAndroid Build Coastguard Worker            ValueError, "Expected op in `op_list` to be an OpOverload"
13625*da0073e9SAndroid Build Coastguard Worker        ):
13626*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts(bad_op_list1)
13627*da0073e9SAndroid Build Coastguard Worker
13628*da0073e9SAndroid Build Coastguard Worker        bad_op_list2 = [torch.ops.aten.sin]
13629*da0073e9SAndroid Build Coastguard Worker
13630*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
13631*da0073e9SAndroid Build Coastguard Worker            ValueError, "update the OpOverloadPacket to a specific OpOverload"
13632*da0073e9SAndroid Build Coastguard Worker        ):
13633*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts(bad_op_list2)
13634*da0073e9SAndroid Build Coastguard Worker
13635*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(TypeError, "either a function or a list of ops."):
13636*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts(2)
13637*da0073e9SAndroid Build Coastguard Worker
13638*da0073e9SAndroid Build Coastguard Worker    # Dynamo fails for various reasons:
13639*da0073e9SAndroid Build Coastguard Worker    # - some tests use a custom op that does not implement Fake
13640*da0073e9SAndroid Build Coastguard Worker    # - dynamo tries to trace into the saved variable hooks' unpack hook for some reason
13641*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13642*da0073e9SAndroid Build Coastguard Worker    def test_policy_with_state(self):
13643*da0073e9SAndroid Build Coastguard Worker        # If I have a stateful callable, state is shared between the original
13644*da0073e9SAndroid Build Coastguard Worker        # forward and the recompute.
13645*da0073e9SAndroid Build Coastguard Worker        counters = []
13646*da0073e9SAndroid Build Coastguard Worker
13647*da0073e9SAndroid Build Coastguard Worker        class Policy:
13648*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
13649*da0073e9SAndroid Build Coastguard Worker                self.counter = [0]
13650*da0073e9SAndroid Build Coastguard Worker                self.recompute_counter = [0]
13651*da0073e9SAndroid Build Coastguard Worker
13652*da0073e9SAndroid Build Coastguard Worker            def __call__(self, ctx, func, *args, **kwargs):
13653*da0073e9SAndroid Build Coastguard Worker                counter = self.recompute_counter if ctx.is_recompute else self.counter
13654*da0073e9SAndroid Build Coastguard Worker                counter[0] += 1
13655*da0073e9SAndroid Build Coastguard Worker                counters.append(counter[0])
13656*da0073e9SAndroid Build Coastguard Worker                if counter[0] == 1 and func is torch.ops.aten.mm.default:
13657*da0073e9SAndroid Build Coastguard Worker                    return CheckpointPolicy.MUST_SAVE
13658*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.PREFER_RECOMPUTE
13659*da0073e9SAndroid Build Coastguard Worker
13660*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13661*da0073e9SAndroid Build Coastguard Worker            return x.sin().sin().sin()
13662*da0073e9SAndroid Build Coastguard Worker
13663*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
13664*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(
13665*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts,
13666*da0073e9SAndroid Build Coastguard Worker            Policy(),
13667*da0073e9SAndroid Build Coastguard Worker            allow_cache_entry_mutation=True,
13668*da0073e9SAndroid Build Coastguard Worker        )
13669*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13670*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
13671*da0073e9SAndroid Build Coastguard Worker        # 1. counter properly reset to 0 for the recompute
13672*da0073e9SAndroid Build Coastguard Worker        # 2. due to early-stop we do not recompute the final op
13673*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counters, [1, 2, 3, 1, 2])
13674*da0073e9SAndroid Build Coastguard Worker
13675*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13676*da0073e9SAndroid Build Coastguard Worker    def test_storage_lifetime(self):
13677*da0073e9SAndroid Build Coastguard Worker        from torch.utils._python_dispatch import _get_current_dispatch_mode
13678*da0073e9SAndroid Build Coastguard Worker        from torch.utils.checkpoint import (
13679*da0073e9SAndroid Build Coastguard Worker            _CachedTorchDispatchMode,
13680*da0073e9SAndroid Build Coastguard Worker            _CachingTorchDispatchMode,
13681*da0073e9SAndroid Build Coastguard Worker        )
13682*da0073e9SAndroid Build Coastguard Worker
13683*da0073e9SAndroid Build Coastguard Worker        def policy_fn(ctx, op, *args, **kwargs):
13684*da0073e9SAndroid Build Coastguard Worker            return CheckpointPolicy.MUST_SAVE
13685*da0073e9SAndroid Build Coastguard Worker
13686*da0073e9SAndroid Build Coastguard Worker        ref = None
13687*da0073e9SAndroid Build Coastguard Worker
13688*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13689*da0073e9SAndroid Build Coastguard Worker            nonlocal ref
13690*da0073e9SAndroid Build Coastguard Worker
13691*da0073e9SAndroid Build Coastguard Worker            self.assertIsInstance(
13692*da0073e9SAndroid Build Coastguard Worker                _get_current_dispatch_mode(),
13693*da0073e9SAndroid Build Coastguard Worker                (_CachingTorchDispatchMode, _CachedTorchDispatchMode),
13694*da0073e9SAndroid Build Coastguard Worker            )
13695*da0073e9SAndroid Build Coastguard Worker
13696*da0073e9SAndroid Build Coastguard Worker            out = x.cos().exp()
13697*da0073e9SAndroid Build Coastguard Worker
13698*da0073e9SAndroid Build Coastguard Worker            if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode):
13699*da0073e9SAndroid Build Coastguard Worker                raw_val = (
13700*da0073e9SAndroid Build Coastguard Worker                    _get_current_dispatch_mode()
13701*da0073e9SAndroid Build Coastguard Worker                    .storage[torch.ops.aten.exp.default][0]
13702*da0073e9SAndroid Build Coastguard Worker                    .val
13703*da0073e9SAndroid Build Coastguard Worker                )
13704*da0073e9SAndroid Build Coastguard Worker                # raw_val should've been detached to avoid the reference cycle
13705*da0073e9SAndroid Build Coastguard Worker                # graph -> the saved variable hooks -> recompute_context -> storage -> graph
13706*da0073e9SAndroid Build Coastguard Worker                self.assertFalse(raw_val.requires_grad)
13707*da0073e9SAndroid Build Coastguard Worker                ref = weakref.ref(raw_val)
13708*da0073e9SAndroid Build Coastguard Worker
13709*da0073e9SAndroid Build Coastguard Worker            # Careful for early-stop
13710*da0073e9SAndroid Build Coastguard Worker            return out.sin()
13711*da0073e9SAndroid Build Coastguard Worker
13712*da0073e9SAndroid Build Coastguard Worker        with disable_gc():
13713*da0073e9SAndroid Build Coastguard Worker            # Case 1: If graph goes away without backward, make sure there's no reference cycle
13714*da0073e9SAndroid Build Coastguard Worker            #         keeping storage alive.
13715*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, requires_grad=True)
13716*da0073e9SAndroid Build Coastguard Worker            context_fn = functools.partial(
13717*da0073e9SAndroid Build Coastguard Worker                create_selective_checkpoint_contexts, policy_fn
13718*da0073e9SAndroid Build Coastguard Worker            )
13719*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13720*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(ref())
13721*da0073e9SAndroid Build Coastguard Worker            del out
13722*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(ref())
13723*da0073e9SAndroid Build Coastguard Worker
13724*da0073e9SAndroid Build Coastguard Worker            # Case 2: After backward, even if retain_graph=True, the storage should go away
13725*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, requires_grad=True)
13726*da0073e9SAndroid Build Coastguard Worker            context_fn = functools.partial(
13727*da0073e9SAndroid Build Coastguard Worker                create_selective_checkpoint_contexts, policy_fn
13728*da0073e9SAndroid Build Coastguard Worker            )
13729*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13730*da0073e9SAndroid Build Coastguard Worker            self.assertIsNotNone(ref())
13731*da0073e9SAndroid Build Coastguard Worker            out.sum().backward(retain_graph=True)
13732*da0073e9SAndroid Build Coastguard Worker            # The dispatch mode's storage should still be alive, but the entries should've
13733*da0073e9SAndroid Build Coastguard Worker            # been cleared.
13734*da0073e9SAndroid Build Coastguard Worker            self.assertIsNone(ref())
13735*da0073e9SAndroid Build Coastguard Worker
13736*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13737*da0073e9SAndroid Build Coastguard Worker    def test_version_counter(self):
13738*da0073e9SAndroid Build Coastguard Worker        def policy_fn(ctx, op, *args, **kwargs):
13739*da0073e9SAndroid Build Coastguard Worker            if op == torch.ops.aten.sin.default:
13740*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.MUST_SAVE
13741*da0073e9SAndroid Build Coastguard Worker            else:
13742*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.PREFER_RECOMPUTE
13743*da0073e9SAndroid Build Coastguard Worker
13744*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13745*da0073e9SAndroid Build Coastguard Worker            return x.sin().mul_(2).cos().exp()
13746*da0073e9SAndroid Build Coastguard Worker
13747*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
13748*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13749*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13750*da0073e9SAndroid Build Coastguard Worker
13751*da0073e9SAndroid Build Coastguard Worker        # 1) Error because the output of sin is saved and mutated by mul_
13752*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "has been mutated"):
13753*da0073e9SAndroid Build Coastguard Worker            out.sum().backward()
13754*da0073e9SAndroid Build Coastguard Worker
13755*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
13756*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(
13757*da0073e9SAndroid Build Coastguard Worker            create_selective_checkpoint_contexts,
13758*da0073e9SAndroid Build Coastguard Worker            policy_fn,
13759*da0073e9SAndroid Build Coastguard Worker            allow_cache_entry_mutation=True,
13760*da0073e9SAndroid Build Coastguard Worker        )
13761*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13762*da0073e9SAndroid Build Coastguard Worker
13763*da0073e9SAndroid Build Coastguard Worker        # 2) No longer should be an error because of allow_cache_entry_mutation
13764*da0073e9SAndroid Build Coastguard Worker        out.sum().backward()
13765*da0073e9SAndroid Build Coastguard Worker
13766*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13767*da0073e9SAndroid Build Coastguard Worker    def test_function_with_more_than_one_output(self):
13768*da0073e9SAndroid Build Coastguard Worker        # maybe there is a more systematic way:
13769*da0073e9SAndroid Build Coastguard Worker        counter = [0]
13770*da0073e9SAndroid Build Coastguard Worker
13771*da0073e9SAndroid Build Coastguard Worker        def policy_fn(ctx, op, *args, **kwargs):
13772*da0073e9SAndroid Build Coastguard Worker            if op == torch.ops.aten.var_mean.correction:
13773*da0073e9SAndroid Build Coastguard Worker                counter[0] += 1
13774*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.MUST_SAVE
13775*da0073e9SAndroid Build Coastguard Worker            else:
13776*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.PREFER_RECOMPUTE
13777*da0073e9SAndroid Build Coastguard Worker
13778*da0073e9SAndroid Build Coastguard Worker        # var_mean has two outputs
13779*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13780*da0073e9SAndroid Build Coastguard Worker            a, b = torch.var_mean(x)
13781*da0073e9SAndroid Build Coastguard Worker            return a * b
13782*da0073e9SAndroid Build Coastguard Worker
13783*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
13784*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13785*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13786*da0073e9SAndroid Build Coastguard Worker        x_grad = torch.autograd.grad(out.sum(), (x,))
13787*da0073e9SAndroid Build Coastguard Worker        x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
13788*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x_grad, x_grad_ref)
13789*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(counter[0], 2)
13790*da0073e9SAndroid Build Coastguard Worker
13791*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13792*da0073e9SAndroid Build Coastguard Worker    def test_function_with_non_tensor_output(self):
13793*da0073e9SAndroid Build Coastguard Worker        # When SAC is enabled, the op is not computed a second time
13794*da0073e9SAndroid Build Coastguard Worker        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
13795*da0073e9SAndroid Build Coastguard Worker            counter = [0]
13796*da0073e9SAndroid Build Coastguard Worker
13797*da0073e9SAndroid Build Coastguard Worker            @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
13798*da0073e9SAndroid Build Coastguard Worker            def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
13799*da0073e9SAndroid Build Coastguard Worker                counter[0] += 1
13800*da0073e9SAndroid Build Coastguard Worker                return x.sin(), 2
13801*da0073e9SAndroid Build Coastguard Worker
13802*da0073e9SAndroid Build Coastguard Worker            def setup_context(ctx, inputs, output) -> None:
13803*da0073e9SAndroid Build Coastguard Worker                (x,) = inputs
13804*da0073e9SAndroid Build Coastguard Worker                ctx.save_for_backward(x)
13805*da0073e9SAndroid Build Coastguard Worker
13806*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, grad, _unused):
13807*da0073e9SAndroid Build Coastguard Worker                (x,) = ctx.saved_tensors
13808*da0073e9SAndroid Build Coastguard Worker                return grad * x.cos()
13809*da0073e9SAndroid Build Coastguard Worker
13810*da0073e9SAndroid Build Coastguard Worker            torch.library.register_autograd(
13811*da0073e9SAndroid Build Coastguard Worker                "mylib::sin_with_extra", backward, setup_context=setup_context
13812*da0073e9SAndroid Build Coastguard Worker            )
13813*da0073e9SAndroid Build Coastguard Worker
13814*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, requires_grad=True)
13815*da0073e9SAndroid Build Coastguard Worker
13816*da0073e9SAndroid Build Coastguard Worker            def fn(x):
13817*da0073e9SAndroid Build Coastguard Worker                return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin()
13818*da0073e9SAndroid Build Coastguard Worker
13819*da0073e9SAndroid Build Coastguard Worker            ops_list = [torch.ops.mylib.sin_with_extra.default]
13820*da0073e9SAndroid Build Coastguard Worker
13821*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(3, requires_grad=True)
13822*da0073e9SAndroid Build Coastguard Worker            context_fn = functools.partial(
13823*da0073e9SAndroid Build Coastguard Worker                create_selective_checkpoint_contexts, ops_list
13824*da0073e9SAndroid Build Coastguard Worker            )
13825*da0073e9SAndroid Build Coastguard Worker            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13826*da0073e9SAndroid Build Coastguard Worker            x_grad = torch.autograd.grad(out.sum(), (x,))
13827*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(counter[0], 1)
13828*da0073e9SAndroid Build Coastguard Worker            x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
13829*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x_grad, x_grad_ref)
13830*da0073e9SAndroid Build Coastguard Worker
13831*da0073e9SAndroid Build Coastguard Worker    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
13832*da0073e9SAndroid Build Coastguard Worker    def test_can_only_trigger_recompute_once(self):
13833*da0073e9SAndroid Build Coastguard Worker        # We don't support this to avoid adding extra complexity for now.
13834*da0073e9SAndroid Build Coastguard Worker        # If there's a need, we could probably do some kind of use_count tracking.
13835*da0073e9SAndroid Build Coastguard Worker        # TODO: have a nice error message here.
13836*da0073e9SAndroid Build Coastguard Worker        def policy_fn(ctx, op, *args, **kwargs):
13837*da0073e9SAndroid Build Coastguard Worker            if op == torch.ops.aten.sin.default:
13838*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.MUST_SAVE
13839*da0073e9SAndroid Build Coastguard Worker            else:
13840*da0073e9SAndroid Build Coastguard Worker                return CheckpointPolicy.PREFER_RECOMPUTE
13841*da0073e9SAndroid Build Coastguard Worker
13842*da0073e9SAndroid Build Coastguard Worker        def fn(x):
13843*da0073e9SAndroid Build Coastguard Worker            return x.sin().cos().exp()
13844*da0073e9SAndroid Build Coastguard Worker
13845*da0073e9SAndroid Build Coastguard Worker        x = torch.randn(3, requires_grad=True)
13846*da0073e9SAndroid Build Coastguard Worker        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
13847*da0073e9SAndroid Build Coastguard Worker        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
13848*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(retain_graph=True)
13849*da0073e9SAndroid Build Coastguard Worker
13850*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"):
13851*da0073e9SAndroid Build Coastguard Worker            out.sum().backward(retain_graph=True)
13852*da0073e9SAndroid Build Coastguard Worker
13853*da0073e9SAndroid Build Coastguard Worker
13854*da0073e9SAndroid Build Coastguard Workerclass TestAutogradMultipleDispatch(TestCase):
13855*da0073e9SAndroid Build Coastguard Worker    def test_autograd_multiple_dispatch_registrations(self, device):
13856*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 3, device=device, requires_grad=True)
13857*da0073e9SAndroid Build Coastguard Worker        # using _test_autograd_multiple_dispatch.fullcoverage which has
13858*da0073e9SAndroid Build Coastguard Worker        # registrations in derivatives.yaml for Default, AutogradCUDA and AutogradNestedTensor
13859*da0073e9SAndroid Build Coastguard Worker        out = torch._test_autograd_multiple_dispatch(t)
13860*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(3, 3, device=device)
13861*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
13862*da0073e9SAndroid Build Coastguard Worker
13863*da0073e9SAndroid Build Coastguard Worker        if "cuda" not in device:
13864*da0073e9SAndroid Build Coastguard Worker            # bogus default gradient registered for Autograd is grad + 1
13865*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.grad, grad + 1)
13866*da0073e9SAndroid Build Coastguard Worker        else:
13867*da0073e9SAndroid Build Coastguard Worker            # bogus gradient registered for AutogradCUDA is grad * 2
13868*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.grad, grad * 2)
13869*da0073e9SAndroid Build Coastguard Worker
13870*da0073e9SAndroid Build Coastguard Worker        # test registered AutogradNestedTensor formula
13871*da0073e9SAndroid Build Coastguard Worker        a = (
13872*da0073e9SAndroid Build Coastguard Worker            torch.arange(6, dtype=torch.float, device=device)
13873*da0073e9SAndroid Build Coastguard Worker            .reshape(2, 3)
13874*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
13875*da0073e9SAndroid Build Coastguard Worker        )
13876*da0073e9SAndroid Build Coastguard Worker        b = (
13877*da0073e9SAndroid Build Coastguard Worker            torch.arange(8, dtype=torch.float, device=device)
13878*da0073e9SAndroid Build Coastguard Worker            .reshape(2, 4)
13879*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
13880*da0073e9SAndroid Build Coastguard Worker        )
13881*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)
13882*da0073e9SAndroid Build Coastguard Worker
13883*da0073e9SAndroid Build Coastguard Worker        nt_out = torch._test_autograd_multiple_dispatch(nt)
13884*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(2, 3, device=device)
13885*da0073e9SAndroid Build Coastguard Worker        d = torch.randn(2, 4, device=device)
13886*da0073e9SAndroid Build Coastguard Worker        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
13887*da0073e9SAndroid Build Coastguard Worker        nt_out.backward(nt_grad)
13888*da0073e9SAndroid Build Coastguard Worker
13889*da0073e9SAndroid Build Coastguard Worker        # bogus gradient for AutogradNestedTensor is grad * grad
13890*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, c * c)
13891*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, d * d)
13892*da0073e9SAndroid Build Coastguard Worker
13893*da0073e9SAndroid Build Coastguard Worker    def test_autograd_composite_implicit_and_dispatch_registration(self, device):
13894*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(3, 3, device=device, requires_grad=True)
13895*da0073e9SAndroid Build Coastguard Worker        # using _test_autograd_multiple_dispatch.ntonly
13896*da0073e9SAndroid Build Coastguard Worker        # which has registrations in derivatives.yaml for AutogradNestedTensor and otherwise is CompositeImplicit
13897*da0073e9SAndroid Build Coastguard Worker        out = torch._test_autograd_multiple_dispatch(t, True)
13898*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(3, 3, device=device)
13899*da0073e9SAndroid Build Coastguard Worker        out.backward(grad)
13900*da0073e9SAndroid Build Coastguard Worker
13901*da0073e9SAndroid Build Coastguard Worker        # t.grad is just out.grad by composite op since _test_autograd_multiple_dispatch is just a clone
13902*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.grad, grad)
13903*da0073e9SAndroid Build Coastguard Worker
13904*da0073e9SAndroid Build Coastguard Worker        # test registered AutogradNestedTensor formula
13905*da0073e9SAndroid Build Coastguard Worker        a = (
13906*da0073e9SAndroid Build Coastguard Worker            torch.arange(6, dtype=torch.float, device=device)
13907*da0073e9SAndroid Build Coastguard Worker            .reshape(2, 3)
13908*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
13909*da0073e9SAndroid Build Coastguard Worker        )
13910*da0073e9SAndroid Build Coastguard Worker        b = (
13911*da0073e9SAndroid Build Coastguard Worker            torch.arange(8, dtype=torch.float, device=device)
13912*da0073e9SAndroid Build Coastguard Worker            .reshape(2, 4)
13913*da0073e9SAndroid Build Coastguard Worker            .requires_grad_(True)
13914*da0073e9SAndroid Build Coastguard Worker        )
13915*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)
13916*da0073e9SAndroid Build Coastguard Worker
13917*da0073e9SAndroid Build Coastguard Worker        nt_out = torch._test_autograd_multiple_dispatch(nt, True)
13918*da0073e9SAndroid Build Coastguard Worker        c = torch.randn(2, 3, device=device)
13919*da0073e9SAndroid Build Coastguard Worker        d = torch.randn(2, 4, device=device)
13920*da0073e9SAndroid Build Coastguard Worker        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
13921*da0073e9SAndroid Build Coastguard Worker        nt_out.backward(nt_grad)
13922*da0073e9SAndroid Build Coastguard Worker
13923*da0073e9SAndroid Build Coastguard Worker        # bogus gradient for AutogradNestedTensor is grad * grad + grad
13924*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad, c * c + c)
13925*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad, d * d + d)
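        # Editorial note: for plain tensors the .ntonly overload has no explicit derivative entry,
        # so it falls back to its CompositeImplicitAutograd implementation (effectively a clone
        # here) and autograd differentiates through that decomposition, giving t.grad == grad
        # above. Nested-tensor inputs instead hit the AutogradNestedTensor formula registered in
        # derivatives.yaml, which is why a.grad and b.grad follow grad * grad + grad.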
13926*da0073e9SAndroid Build Coastguard Worker
13927*da0073e9SAndroid Build Coastguard Worker    def test_forward_mode_AD(self, device):
13928*da0073e9SAndroid Build Coastguard Worker        # check that forward-mode AD is only registered for the Default
13929*da0073e9SAndroid Build Coastguard Worker        # dispatch key of _test_autograd_multiple_dispatch.fullcoverage and not for AutogradCUDA
13930*da0073e9SAndroid Build Coastguard Worker
13931*da0073e9SAndroid Build Coastguard Worker        primal = torch.randn(3, device=device)
13932*da0073e9SAndroid Build Coastguard Worker        tangent = torch.randn(3, device=device)
13933*da0073e9SAndroid Build Coastguard Worker
13934*da0073e9SAndroid Build Coastguard Worker        with fwAD.dual_level():
13935*da0073e9SAndroid Build Coastguard Worker            dual_input = fwAD.make_dual(primal, tangent)
13936*da0073e9SAndroid Build Coastguard Worker
13937*da0073e9SAndroid Build Coastguard Worker            err_msg = r"Trying to use forward AD with .* that does not support it"
13938*da0073e9SAndroid Build Coastguard Worker            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"
13939*da0073e9SAndroid Build Coastguard Worker
13940*da0073e9SAndroid Build Coastguard Worker            if "cuda" in device:
13941*da0073e9SAndroid Build Coastguard Worker                with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
13942*da0073e9SAndroid Build Coastguard Worker                    torch._test_autograd_multiple_dispatch(dual_input)
13943*da0073e9SAndroid Build Coastguard Worker            else:
13944*da0073e9SAndroid Build Coastguard Worker                torch._test_autograd_multiple_dispatch(dual_input)
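        # Editorial note: forward-mode AD propagates a tangent alongside each primal. A minimal,
        # standalone sketch of the API exercised above (independent of this test's op):
        #
        #   with fwAD.dual_level():
        #       dual = fwAD.make_dual(torch.randn(3), torch.ones(3))
        #       out = dual * 2
        #       _, tangent_out = fwAD.unpack_dual(out)  # tangent_out is 2 * the input tangent
        #
        # The CUDA branch raises because only the Default dispatch entry defines a forward-AD
        # formula for this op.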
13945*da0073e9SAndroid Build Coastguard Worker
13946*da0073e9SAndroid Build Coastguard Worker    def test_view_copy(self, device):
13947*da0073e9SAndroid Build Coastguard Worker        # tests that view_copy derivative formulas are also generated per dispatch key
13948*da0073e9SAndroid Build Coastguard Worker        # from their respective view ops in derivatives.yaml
13949*da0073e9SAndroid Build Coastguard Worker        t = torch.randn(2, 2, device=device, requires_grad=True)
13950*da0073e9SAndroid Build Coastguard Worker        t_ref = t.clone().detach().requires_grad_()
13951*da0073e9SAndroid Build Coastguard Worker        # _test_autograd_multiple_dispatch_view does a .view(-1) on the input
13952*da0073e9SAndroid Build Coastguard Worker        t_view = torch._test_autograd_multiple_dispatch_view(t_ref)
13953*da0073e9SAndroid Build Coastguard Worker        t_view_copy = torch._test_autograd_multiple_dispatch_view_copy(t)
13954*da0073e9SAndroid Build Coastguard Worker
13955*da0073e9SAndroid Build Coastguard Worker        grad = torch.randn(4, device=device)
13956*da0073e9SAndroid Build Coastguard Worker        t_view_copy.backward(grad)
13957*da0073e9SAndroid Build Coastguard Worker        t_view.backward(grad.clone())
13958*da0073e9SAndroid Build Coastguard Worker
13959*da0073e9SAndroid Build Coastguard Worker        # forward and backward give the same shape + result
13960*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t_view_copy, t_view)
13961*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(t.grad, t_ref.grad)
13962*da0073e9SAndroid Build Coastguard Worker        # backward results are per-dispatch-key in derivatives.yaml
13963*da0073e9SAndroid Build Coastguard Worker        if "cuda" in device:
13964*da0073e9SAndroid Build Coastguard Worker            # gradient registered to AutogradCUDA is grad.reshape_as(self) + 1
13965*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.grad, grad.reshape_as(t) + 1)
13966*da0073e9SAndroid Build Coastguard Worker        else:
13967*da0073e9SAndroid Build Coastguard Worker            # Default gradient registered is grad.reshape_as(self)
13968*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(t.grad, grad.reshape_as(t))
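        # Editorial note: *_view_copy ops are the non-aliasing counterparts of view ops (used,
        # e.g., by functionalization). As the comments above say, their backward formulas are
        # generated per dispatch key from the corresponding view op's derivatives.yaml entry,
        # which is why the CUDA gradient picks up the extra "+ 1" while the default key is a
        # plain reshape.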
13969*da0073e9SAndroid Build Coastguard Worker
13970*da0073e9SAndroid Build Coastguard Worker    @onlyCPU
13971*da0073e9SAndroid Build Coastguard Worker    def test_per_dispatch_key_input_saving(self, device):
13972*da0073e9SAndroid Build Coastguard Worker        # Tests that sum.dim_IntList's input is not saved for regular tensors but is saved for nested tensors
13973*da0073e9SAndroid Build Coastguard Worker        def foo(x):
13974*da0073e9SAndroid Build Coastguard Worker            # Don't modify the input inplace
13975*da0073e9SAndroid Build Coastguard Worker            x = x.clone()
13976*da0073e9SAndroid Build Coastguard Worker            res = x.sum(-1, keepdim=True)
13977*da0073e9SAndroid Build Coastguard Worker            x.add_(x)
13978*da0073e9SAndroid Build Coastguard Worker            return res
13979*da0073e9SAndroid Build Coastguard Worker
13980*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(2, device=device, requires_grad=True)
13981*da0073e9SAndroid Build Coastguard Worker        # sum's input is not saved for regular Tensors
13982*da0073e9SAndroid Build Coastguard Worker        foo(inp).backward()
13983*da0073e9SAndroid Build Coastguard Worker
13984*da0073e9SAndroid Build Coastguard Worker        # sum's input is saved for Nested Tensors
13985*da0073e9SAndroid Build Coastguard Worker        nt = torch.nested.nested_tensor(
13986*da0073e9SAndroid Build Coastguard Worker            [torch.rand(2), torch.rand(2)], device=device, requires_grad=True
13987*da0073e9SAndroid Build Coastguard Worker        )
13988*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
13989*da0073e9SAndroid Build Coastguard Worker            foo(nt).backward(
13990*da0073e9SAndroid Build Coastguard Worker                torch.nested.nested_tensor(
13991*da0073e9SAndroid Build Coastguard Worker                    [torch.rand(1), torch.rand(1)], device=device
13992*da0073e9SAndroid Build Coastguard Worker                )
13993*da0073e9SAndroid Build Coastguard Worker            )
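        # Editorial note: for dense tensors sum's backward only needs the input's shape, so the
        # input is not saved and the in-place x.add_(x) above is harmless. The nested-tensor
        # formula saves the input itself, so mutating it bumps its version counter and backward
        # raises the "modified by an inplace operation" error asserted here.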
13994*da0073e9SAndroid Build Coastguard Worker
13995*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
13996*da0073e9SAndroid Build Coastguard Worker    def test_backward_single_threaded(self):
13997*da0073e9SAndroid Build Coastguard Worker        threads_eq = None
13998*da0073e9SAndroid Build Coastguard Worker
13999*da0073e9SAndroid Build Coastguard Worker        class TestFn(Function):
14000*da0073e9SAndroid Build Coastguard Worker            @staticmethod
14001*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, self):
14002*da0073e9SAndroid Build Coastguard Worker                ctx.self = self
14003*da0073e9SAndroid Build Coastguard Worker                ctx.tid = threading.get_ident()
14004*da0073e9SAndroid Build Coastguard Worker                return x.clone()
14005*da0073e9SAndroid Build Coastguard Worker
14006*da0073e9SAndroid Build Coastguard Worker            @staticmethod
14007*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
14008*da0073e9SAndroid Build Coastguard Worker                nonlocal threads_eq
14009*da0073e9SAndroid Build Coastguard Worker                threads_eq = ctx.tid == threading.get_ident()
14010*da0073e9SAndroid Build Coastguard Worker                return gO, None
14011*da0073e9SAndroid Build Coastguard Worker
14012*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(10, device="cuda", requires_grad=True)
14013*da0073e9SAndroid Build Coastguard Worker
14014*da0073e9SAndroid Build Coastguard Worker        with torch.autograd.set_multithreading_enabled(False):
14015*da0073e9SAndroid Build Coastguard Worker            TestFn.apply(inp, None).sum().backward()
14016*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(threads_eq)
14017*da0073e9SAndroid Build Coastguard Worker
14018*da0073e9SAndroid Build Coastguard Worker        TestFn.apply(inp, None).sum().backward()
14019*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(threads_eq)
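        # Editorial note: with multithreading enabled (the default), the engine runs this CUDA
        # graph's backward nodes on a worker thread, so the thread id recorded in forward differs
        # from the one seen in backward (threads_eq is False). Under
        # torch.autograd.set_multithreading_enabled(False) the backward pass runs on the calling
        # thread, so the ids match.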
14020*da0073e9SAndroid Build Coastguard Worker
14021*da0073e9SAndroid Build Coastguard Worker    @onlyCUDA
14022*da0073e9SAndroid Build Coastguard Worker    def test_backward_tls_stash(self):
14023*da0073e9SAndroid Build Coastguard Worker        local = threading.local()
14024*da0073e9SAndroid Build Coastguard Worker        local.my_obj = {}
14025*da0073e9SAndroid Build Coastguard Worker        local.my_obj[10] = 10
14026*da0073e9SAndroid Build Coastguard Worker        test_self = self
14027*da0073e9SAndroid Build Coastguard Worker        torch._C._stash_obj_in_tls("my_obj", local.my_obj)
14028*da0073e9SAndroid Build Coastguard Worker
14029*da0073e9SAndroid Build Coastguard Worker        class TestFn(Function):
14030*da0073e9SAndroid Build Coastguard Worker            @staticmethod
14031*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x, self):
14032*da0073e9SAndroid Build Coastguard Worker                return x.clone()
14033*da0073e9SAndroid Build Coastguard Worker
14034*da0073e9SAndroid Build Coastguard Worker            @staticmethod
14035*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
14036*da0073e9SAndroid Build Coastguard Worker                test_self.assertTrue(torch._C._is_key_in_tls("my_obj"))
14037*da0073e9SAndroid Build Coastguard Worker                test_self.assertTrue(torch._C._get_obj_in_tls("my_obj")[10] == 10)
14038*da0073e9SAndroid Build Coastguard Worker                torch._C._get_obj_in_tls("my_obj")[10] = 5
14039*da0073e9SAndroid Build Coastguard Worker                return gO, None
14040*da0073e9SAndroid Build Coastguard Worker
14041*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(10, device="cuda", requires_grad=True)
14042*da0073e9SAndroid Build Coastguard Worker
14043*da0073e9SAndroid Build Coastguard Worker        TestFn.apply(inp, None).sum().backward()
14044*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(local.my_obj[10], 5)
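        # Editorial note: torch._C._stash_obj_in_tls is a private helper that makes an object
        # visible through thread-local state propagated to the autograd worker threads. backward()
        # here runs on a different (CUDA worker) thread, yet it can read and mutate the stashed
        # dict; since it is the same object, the mutation (10 -> 5) is observed on the main thread.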
14045*da0073e9SAndroid Build Coastguard Worker
14046*da0073e9SAndroid Build Coastguard Worker    def test_is_retain_graph(self):
14047*da0073e9SAndroid Build Coastguard Worker        retain_graph_set = False
14048*da0073e9SAndroid Build Coastguard Worker
14049*da0073e9SAndroid Build Coastguard Worker        class TestFn(Function):
14050*da0073e9SAndroid Build Coastguard Worker            @staticmethod
14051*da0073e9SAndroid Build Coastguard Worker            def forward(ctx, x):
14052*da0073e9SAndroid Build Coastguard Worker                return x.clone()
14053*da0073e9SAndroid Build Coastguard Worker
14054*da0073e9SAndroid Build Coastguard Worker            @staticmethod
14055*da0073e9SAndroid Build Coastguard Worker            def backward(ctx, gO):
14056*da0073e9SAndroid Build Coastguard Worker                nonlocal retain_graph_set
14057*da0073e9SAndroid Build Coastguard Worker                retain_graph_set = (
14058*da0073e9SAndroid Build Coastguard Worker                    torch._C._autograd._get_current_graph_task_keep_graph()
14059*da0073e9SAndroid Build Coastguard Worker                )
14060*da0073e9SAndroid Build Coastguard Worker                return gO, None
14061*da0073e9SAndroid Build Coastguard Worker
14062*da0073e9SAndroid Build Coastguard Worker        inp = torch.rand(10, requires_grad=True)
14063*da0073e9SAndroid Build Coastguard Worker
14064*da0073e9SAndroid Build Coastguard Worker        out = TestFn.apply(inp)
14065*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(retain_graph_set)
14066*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(retain_graph=True)
14067*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(retain_graph_set)
14068*da0073e9SAndroid Build Coastguard Worker        out.sum().backward(retain_graph=False)
14069*da0073e9SAndroid Build Coastguard Worker        self.assertFalse(retain_graph_set)
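        # Editorial note: torch._C._autograd._get_current_graph_task_keep_graph reports whether
        # the currently executing graph task was started with retain_graph=True, so the value
        # captured inside TestFn.backward is True for the first backward call and False for the
        # second.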
14070*da0073e9SAndroid Build Coastguard Worker
14071*da0073e9SAndroid Build Coastguard Worker    def test_set_sequence_nr(self):
14072*da0073e9SAndroid Build Coastguard Worker        x = torch.randn((10,), dtype=torch.float32, requires_grad=True)
14073*da0073e9SAndroid Build Coastguard Worker        y = torch.randn((10,), dtype=torch.float32, requires_grad=True)
14074*da0073e9SAndroid Build Coastguard Worker        z = torch.randn((10,), dtype=torch.float32, requires_grad=True)
14075*da0073e9SAndroid Build Coastguard Worker
14076*da0073e9SAndroid Build Coastguard Worker        a = x + y
14077*da0073e9SAndroid Build Coastguard Worker        b = y + z
14078*da0073e9SAndroid Build Coastguard Worker        c = a + b
14079*da0073e9SAndroid Build Coastguard Worker
14080*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(a.grad_fn)
14081*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(b.grad_fn)
14082*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(c.grad_fn)
14083*da0073e9SAndroid Build Coastguard Worker
14084*da0073e9SAndroid Build Coastguard Worker        a.grad_fn._set_sequence_nr(100)
14085*da0073e9SAndroid Build Coastguard Worker        b.grad_fn._set_sequence_nr(99)
14086*da0073e9SAndroid Build Coastguard Worker        c.grad_fn._set_sequence_nr(98)
14087*da0073e9SAndroid Build Coastguard Worker
14088*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(a.grad_fn._sequence_nr(), 100)
14089*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(b.grad_fn._sequence_nr(), 99)
14090*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(c.grad_fn._sequence_nr(), 98)
14091*da0073e9SAndroid Build Coastguard Worker
14092*da0073e9SAndroid Build Coastguard Worker        def log_grad_order(grad: torch.Tensor, name: str, order):
14093*da0073e9SAndroid Build Coastguard Worker            order.append(name)
14094*da0073e9SAndroid Build Coastguard Worker            return grad
14095*da0073e9SAndroid Build Coastguard Worker
14096*da0073e9SAndroid Build Coastguard Worker        order = []
14097*da0073e9SAndroid Build Coastguard Worker        a.register_hook(partial(log_grad_order, name="a", order=order))
14098*da0073e9SAndroid Build Coastguard Worker        b.register_hook(partial(log_grad_order, name="b", order=order))
14099*da0073e9SAndroid Build Coastguard Worker        c.register_hook(partial(log_grad_order, name="c", order=order))
14100*da0073e9SAndroid Build Coastguard Worker
14101*da0073e9SAndroid Build Coastguard Worker        c.sum().backward()
14102*da0073e9SAndroid Build Coastguard Worker
14103*da0073e9SAndroid Build Coastguard Worker        # Expect that even though c has the smallest sequence number, it is still the first node to run in backward
14104*da0073e9SAndroid Build Coastguard Worker        # (it is the graph's output node). Also check that although a was created before b in the forward pass,
14105*da0073e9SAndroid Build Coastguard Worker        # giving a the higher sequence_nr makes its autograd node run before b's.
14106*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(order, ["c", "a", "b"])
14107*da0073e9SAndroid Build Coastguard Worker
14108*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x.grad, torch.ones_like(x))
14109*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(y.grad, 2 * torch.ones_like(x))
14110*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(z.grad, torch.ones_like(x))
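        # Editorial note: among nodes whose dependencies are already satisfied, the engine's ready
        # queue prefers higher sequence_nr (normally this mirrors reverse forward-execution order);
        # overriding the numbers above is what makes a's node run before b's. The gradient values
        # follow from c = (x + y) + (y + z): d(c.sum())/dx = 1, /dy = 2, /dz = 1.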
14111*da0073e9SAndroid Build Coastguard Worker
14112*da0073e9SAndroid Build Coastguard Worker
14113*da0073e9SAndroid Build Coastguard Worker# Import test cases from the autograd/ directory below. These are found
14114*da0073e9SAndroid Build Coastguard Worker# implicitly by the loader, so Flake8 thinks they are unused, hence
14115*da0073e9SAndroid Build Coastguard Worker# the suppressions.
14116*da0073e9SAndroid Build Coastguard Worker
14117*da0073e9SAndroid Build Coastguard Workerfrom autograd.test_complex import TestAutogradComplex  # noqa: F401
14118*da0073e9SAndroid Build Coastguard Workerfrom autograd.test_functional import TestAutogradFunctional  # noqa: F401
14119*da0073e9SAndroid Build Coastguard Workerfrom autograd.test_logging import TestAutogradLogging  # noqa: F401
14120*da0073e9SAndroid Build Coastguard Worker
14121*da0073e9SAndroid Build Coastguard Worker
14122*da0073e9SAndroid Build Coastguard Worker# e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
14123*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None)
14124*da0073e9SAndroid Build Coastguard Worker
14125*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(
14126*da0073e9SAndroid Build Coastguard Worker    TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda")
14127*da0073e9SAndroid Build Coastguard Worker)
14128*da0073e9SAndroid Build Coastguard Worker
14129*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestAutograd)
14130*da0073e9SAndroid Build Coastguard Workerinstantiate_parametrized_tests(TestNestedCheckpoint)
14131*da0073e9SAndroid Build Coastguard Worker
14132*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
14133*da0073e9SAndroid Build Coastguard Worker    run_tests()