xref: /aosp_15_r20/external/pytorch/test/autograd/test_functional.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: autograd"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport types
4*da0073e9SAndroid Build Coastguard Workerimport unittest
5*da0073e9SAndroid Build Coastguard Workerimport warnings
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport torch.autograd.functional as autogradF
9*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_cuda import TEST_CUDA
10*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import (
11*da0073e9SAndroid Build Coastguard Worker    gradcheck,
12*da0073e9SAndroid Build Coastguard Worker    gradgradcheck,
13*da0073e9SAndroid Build Coastguard Worker    instantiate_parametrized_tests,
14*da0073e9SAndroid Build Coastguard Worker    parametrize,
15*da0073e9SAndroid Build Coastguard Worker    run_tests,
16*da0073e9SAndroid Build Coastguard Worker    subtest,
17*da0073e9SAndroid Build Coastguard Worker    TestCase,
18*da0073e9SAndroid Build Coastguard Worker)
19*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import LoggingTensor
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
22*da0073e9SAndroid Build Coastguard Worker# Utilities for parametrizing the tensor constructors used in autograd tests
23*da0073e9SAndroid Build Coastguard Worker#
24*da0073e9SAndroid Build Coastguard Worker# TODO: maybe move somewhere so other tests can also use
25*da0073e9SAndroid Build Coastguard Worker#
26*da0073e9SAndroid Build Coastguard Worker# NB: Not all factory functions included. A complete(?) list can be found here:
27*da0073e9SAndroid Build Coastguard Worker#     https://pytorch.org/cppdocs/notes/tensor_creation.html
# Plain ``torch`` factory functions keyed by their own attribute name, in the
# order ones, zeros, randn, rand, tensor.
base_ctors_dict = {
    factory_name: getattr(torch, factory_name)
    for factory_name in ("ones", "zeros", "randn", "rand", "tensor")
}
# Attribute-style view of the same mapping, so tests can call ``ctors.rand(...)``.
base_ctors = types.SimpleNamespace(**base_ctors_dict)
36*da0073e9SAndroid Build Coastguard Worker
37*da0073e9SAndroid Build Coastguard Worker
def wrap_with_logging_tensor(ctor):
    """Return a constructor that wraps the result of ``ctor`` in a LoggingTensor.

    The returned callable forwards all positional and keyword arguments to
    ``ctor`` except ``requires_grad``, which is removed from the keyword
    arguments (defaulting to ``False``) and passed to ``LoggingTensor`` instead.
    """

    def wrapper(*args, **kwargs):
        # ``requires_grad`` is applied to the LoggingTensor, not to the tensor
        # built by ``ctor``.
        needs_grad = kwargs.pop("requires_grad", False)
        plain_tensor = ctor(*args, **kwargs)
        return LoggingTensor(plain_tensor, requires_grad=needs_grad)

    return wrapper
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
# LoggingTensor-producing counterpart of every entry in ``base_ctors_dict``,
# registered under the same key.
logging_tensor_ctors_dict = dict(
    (ctor_name, wrap_with_logging_tensor(plain_ctor))
    for ctor_name, plain_ctor in base_ctors_dict.items()
)
logging_tensor_ctors = types.SimpleNamespace(**logging_tensor_ctors_dict)
50*da0073e9SAndroid Build Coastguard Worker
# Parametrizes the ``ctors`` argument of a test over the plain constructors
# ("base_tensor") and the LoggingTensor constructors ("logging_tensor").
base_and_logging_tensor = parametrize(
    "ctors",
    [
        subtest(ctor_namespace, name=case_name)
        for case_name, ctor_namespace in (
            ("base_tensor", base_ctors),
            ("logging_tensor", logging_tensor_ctors),
        )
    ],
)
58*da0073e9SAndroid Build Coastguard Worker
# Same two cases as ``base_and_logging_tensor``, except the "logging_tensor"
# case is additionally decorated with ``unittest.expectedFailure``.
FIXME_base_and_xfail_logging_tensor = parametrize(
    "ctors",
    [
        subtest(ctor_namespace, name=case_name, **extra_kwargs)
        for case_name, ctor_namespace, extra_kwargs in (
            ("base_tensor", base_ctors, {}),
            (
                "logging_tensor",
                logging_tensor_ctors,
                {"decorators": [unittest.expectedFailure]},
            ),
        )
    ],
)
70*da0073e9SAndroid Build Coastguard Worker
# NB: This is equivalent to stacking @parametrize("vectorize", [True, False]) on top of
#     FIXME_base_and_xfail_logging_tensor, except that the non-vectorized logging_tensor
#     case is actually expected to succeed (only vectorize=True is marked xfail).
FIXME_xfail_vectorized_logging_tensor = parametrize(
    "vectorize,ctors",
    [
        subtest((use_vectorize, ctor_namespace), name=case_name, **extra_kwargs)
        for use_vectorize, ctor_namespace, case_name, extra_kwargs in (
            (True, base_ctors, "vectorized_base_tensor", {}),
            (False, base_ctors, "base_tensor", {}),
            (
                True,
                logging_tensor_ctors,
                "vectorized_logging_tensor",
                {"decorators": [unittest.expectedFailure]},
            ),
            (False, logging_tensor_ctors, "logging_tensor", {}),
        )
    ],
)
87*da0073e9SAndroid Build Coastguard Worker
# Parametrizes ``vectorize`` and ``ctors`` together: both values of ``vectorize``
# for the plain constructors first, then for the LoggingTensor constructors.
vectorized_logging_tensor = parametrize(
    "vectorize,ctors",
    [
        subtest(
            (use_vectorize, ctor_namespace),
            name=("vectorized_" if use_vectorize else "") + case_suffix,
        )
        for case_suffix, ctor_namespace in (
            ("base_tensor", base_ctors),
            ("logging_tensor", logging_tensor_ctors),
        )
        for use_vectorize in (True, False)
    ],
)
97*da0073e9SAndroid Build Coastguard Worker
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Workerclass TestAutogradFunctional(TestCase):
100*da0073e9SAndroid Build Coastguard Worker    def _assert_same_struct(self, res, base):
101*da0073e9SAndroid Build Coastguard Worker        # base and res should be Tensors or tuple of Tensors with the same size
102*da0073e9SAndroid Build Coastguard Worker        if isinstance(base, torch.Tensor):
103*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(res, torch.Tensor))
104*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(base.size(), res.size())
105*da0073e9SAndroid Build Coastguard Worker        elif isinstance(base, tuple):
106*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(res, tuple))
107*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(base), len(res))
108*da0073e9SAndroid Build Coastguard Worker            for el_base, el_res in zip(base, res):
109*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_base, torch.Tensor))
110*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_res, torch.Tensor))
111*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(el_base.size(), el_res.size())
112*da0073e9SAndroid Build Coastguard Worker        else:
113*da0073e9SAndroid Build Coastguard Worker            # Wrong base
114*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
115*da0073e9SAndroid Build Coastguard Worker                "The base given to `_assert_same_struct` doesn't have"
116*da0073e9SAndroid Build Coastguard Worker                " the right structure."
117*da0073e9SAndroid Build Coastguard Worker            )
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker    def _assert_interleaved_struct(self, res, base1, base2):
120*da0073e9SAndroid Build Coastguard Worker        # base1 and base2 can be Tensors or tuples of Tensors.
121*da0073e9SAndroid Build Coastguard Worker        # If they are tuples, res should be a tuple as well.
122*da0073e9SAndroid Build Coastguard Worker        # The indexing works as follows for base1, base2 being
123*da0073e9SAndroid Build Coastguard Worker        # - tuple, tuple: res[i][j][k][l] = (base1[i][k], base2[j][l])
124*da0073e9SAndroid Build Coastguard Worker        # - tuple, Tensor: res[i][k][l] = (base1[i][k], base2[l])
125*da0073e9SAndroid Build Coastguard Worker        # - Tensor, tuple: res[i][j][l] = (base1[i], base2[j][l])
126*da0073e9SAndroid Build Coastguard Worker        # - Tensor, Tensor: res[k][l] = (base1[k], base2[l])
127*da0073e9SAndroid Build Coastguard Worker        if isinstance(base1, torch.Tensor) and isinstance(base2, torch.Tensor):
128*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(res, torch.Tensor))
129*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(res.size(), base1.size() + base2.size())
130*da0073e9SAndroid Build Coastguard Worker        elif isinstance(base1, tuple) and isinstance(base2, torch.Tensor):
131*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(res, tuple))
132*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(res), len(base1))
133*da0073e9SAndroid Build Coastguard Worker            for el_res, el_base1 in zip(res, base1):
134*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_res, torch.Tensor))
135*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_base1, torch.Tensor))
136*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(el_res.size(), el_base1.size() + base2.size())
137*da0073e9SAndroid Build Coastguard Worker        elif isinstance(base1, torch.Tensor) and isinstance(base2, tuple):
138*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(res, tuple))
139*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(res), len(base2))
140*da0073e9SAndroid Build Coastguard Worker            for el_res, el_base2 in zip(res, base2):
141*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_res, torch.Tensor))
142*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_base2, torch.Tensor))
143*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(el_res.size(), base1.size() + el_base2.size())
144*da0073e9SAndroid Build Coastguard Worker        elif isinstance(base1, tuple) and isinstance(base2, tuple):
145*da0073e9SAndroid Build Coastguard Worker            self.assertTrue(isinstance(res, tuple))
146*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(len(res), len(base1))
147*da0073e9SAndroid Build Coastguard Worker            for el_res, el_base1 in zip(res, base1):
148*da0073e9SAndroid Build Coastguard Worker                self.assertTrue(isinstance(el_res, tuple))
149*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(len(res), len(base2))
150*da0073e9SAndroid Build Coastguard Worker                for el_el_res, el_base2 in zip(el_res, base2):
151*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(isinstance(el_el_res, torch.Tensor))
152*da0073e9SAndroid Build Coastguard Worker                    self.assertTrue(isinstance(el_base2, torch.Tensor))
153*da0073e9SAndroid Build Coastguard Worker                    self.assertEqual(
154*da0073e9SAndroid Build Coastguard Worker                        el_el_res.size(), el_base1.size() + el_base2.size()
155*da0073e9SAndroid Build Coastguard Worker                    )
156*da0073e9SAndroid Build Coastguard Worker        else:
157*da0073e9SAndroid Build Coastguard Worker            # Wrong bases
158*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
159*da0073e9SAndroid Build Coastguard Worker                "The bases given to `_assert_interleaved_struct` don't have"
160*da0073e9SAndroid Build Coastguard Worker                " the right structure."
161*da0073e9SAndroid Build Coastguard Worker            )
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
164*da0073e9SAndroid Build Coastguard Worker    def test_vjp_err_check(self, ctors):
165*da0073e9SAndroid Build Coastguard Worker        def foo(a):
166*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3)
167*da0073e9SAndroid Build Coastguard Worker
168*da0073e9SAndroid Build Coastguard Worker        def bar(a):
169*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3), "bar"
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
172*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(3)
173*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
174*da0073e9SAndroid Build Coastguard Worker            TypeError, "The inputs given to vjp must be either a Tensor"
175*da0073e9SAndroid Build Coastguard Worker        ):
176*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(foo, (inp, 2), v)
177*da0073e9SAndroid Build Coastguard Worker
178*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
179*da0073e9SAndroid Build Coastguard Worker            TypeError, "The outputs of the user-provided function given to vjp must"
180*da0073e9SAndroid Build Coastguard Worker        ):
181*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(bar, inp, v)
182*da0073e9SAndroid Build Coastguard Worker
183*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
184*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
185*da0073e9SAndroid Build Coastguard Worker            "The vector v can only be None if the user-provided function returns",
186*da0073e9SAndroid Build Coastguard Worker        ):
187*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(foo, inp)
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
190*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "The given v should contain a single Tensor."
191*da0073e9SAndroid Build Coastguard Worker        ):
192*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(foo, inp, (torch.ones_like(inp), torch.ones_like(inp)))
193*da0073e9SAndroid Build Coastguard Worker
194*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
195*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "v has invalid size: should be torch.Size"
196*da0073e9SAndroid Build Coastguard Worker        ):
197*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(foo, inp, v[:2])
198*da0073e9SAndroid Build Coastguard Worker
199*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(foo, inp, v)[1]
200*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res, inp)
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
203*da0073e9SAndroid Build Coastguard Worker    def test_vjp_err_check_strict(self, ctors):
204*da0073e9SAndroid Build Coastguard Worker        def foo(a):
205*da0073e9SAndroid Build Coastguard Worker            return a.detach()
206*da0073e9SAndroid Build Coastguard Worker
207*da0073e9SAndroid Build Coastguard Worker        def bar(a):
208*da0073e9SAndroid Build Coastguard Worker            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
209*da0073e9SAndroid Build Coastguard Worker            return a.long().float().requires_grad_().clone()
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
212*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand(4)
213*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
214*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
215*da0073e9SAndroid Build Coastguard Worker            "Output 0 of the user-provided function does not require gradients.",
216*da0073e9SAndroid Build Coastguard Worker        ):
217*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(foo, inp, v, strict=True)
218*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(foo, inp, v, strict=False)
219*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inp)
220*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res[1].abs().sum(), 0.0)
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
223*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
224*da0073e9SAndroid Build Coastguard Worker            "The output of the user-provided function is independent of input 0",
225*da0073e9SAndroid Build Coastguard Worker        ):
226*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(bar, inp, v, strict=True)
227*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(bar, inp, v, strict=False)
228*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inp)
229*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res[1].abs().sum(), 0.0)
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker        # The Jacobian does not depend on the input
232*da0073e9SAndroid Build Coastguard Worker        def foo(a):
233*da0073e9SAndroid Build Coastguard Worker            return a.clone()
234*da0073e9SAndroid Build Coastguard Worker
235*da0073e9SAndroid Build Coastguard Worker        inp.requires_grad_()
236*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
237*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
238*da0073e9SAndroid Build Coastguard Worker            "jacobian of the user-provided function is independent of input 0.",
239*da0073e9SAndroid Build Coastguard Worker        ):
240*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(foo, inp, v, create_graph=True, strict=True)
241*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(foo, inp, v, create_graph=True, strict=False)
242*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inp)
243*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res[1], v)
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
246*da0073e9SAndroid Build Coastguard Worker    def test_vjp_no_grad(self, ctors):
247*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
248*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=1)
249*da0073e9SAndroid Build Coastguard Worker
250*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
251*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4)
252*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
253*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(reducer, inputs, v)
254*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0].grad_fn)
255*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1].grad_fn)
256*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res[1], ctors.zeros(4, 4))
257*da0073e9SAndroid Build Coastguard Worker
258*da0073e9SAndroid Build Coastguard Worker        inputs.requires_grad_()
259*da0073e9SAndroid Build Coastguard Worker        v.requires_grad_()
260*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
261*da0073e9SAndroid Build Coastguard Worker            res = autogradF.vjp(reducer, inputs, v, create_graph=True)
262*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0].grad_fn)
263*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1].grad_fn)
264*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res[1], ctors.zeros(4, 4))
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
267*da0073e9SAndroid Build Coastguard Worker    def test_vjp_output(self, ctors):
268*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
269*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=1)
270*da0073e9SAndroid Build Coastguard Worker
271*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
272*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4)
273*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(reducer, inputs, v)
274*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
275*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0].grad_fn)
276*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1].grad_fn)
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
279*da0073e9SAndroid Build Coastguard Worker            return 2 * x + 3 * y
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker        inputs = (ctors.rand(2), ctors.rand(2))
282*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(2)
283*da0073e9SAndroid Build Coastguard Worker        out, vjp_val = autogradF.vjp(adder, inputs, v)
284*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(vjp_val, inputs)
285*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out.grad_fn)
286*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(vjp_val[0].grad_fn)
287*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(vjp_val[1].grad_fn)
288*da0073e9SAndroid Build Coastguard Worker
289*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
290*da0073e9SAndroid Build Coastguard Worker            return 2 * x + 3 * y, x + y
291*da0073e9SAndroid Build Coastguard Worker
292*da0073e9SAndroid Build Coastguard Worker        inputs = (ctors.rand(2), ctors.rand(2))
293*da0073e9SAndroid Build Coastguard Worker        v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0]))
294*da0073e9SAndroid Build Coastguard Worker        out, vjp_val = autogradF.vjp(adder, inputs, v)
295*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(vjp_val, inputs)
296*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out[0].grad_fn)
297*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out[1].grad_fn)
298*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(vjp_val[0].grad_fn)
299*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(vjp_val[1].grad_fn)
300*da0073e9SAndroid Build Coastguard Worker
301*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
302*da0073e9SAndroid Build Coastguard Worker    def test_vjp_scalar(self, ctors):
303*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
304*da0073e9SAndroid Build Coastguard Worker            return x.sum()
305*da0073e9SAndroid Build Coastguard Worker
306*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
307*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones([])
308*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(reducer, inputs, v)
309*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[0], v)
310*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
311*da0073e9SAndroid Build Coastguard Worker
312*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(reducer, inputs)
313*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[0], v)
314*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        def expander(x):
317*da0073e9SAndroid Build Coastguard Worker            return x.unsqueeze(0).repeat(4)
318*da0073e9SAndroid Build Coastguard Worker
319*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand([])
320*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4)
321*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(expander, inputs, v)
322*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[0], v)
323*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
326*da0073e9SAndroid Build Coastguard Worker    def test_vjp_create_graph(self, ctors):
327*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
328*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=1)
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(2, 2, dtype=torch.double)
331*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(2, dtype=torch.double)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker        inputs.requires_grad_()
334*da0073e9SAndroid Build Coastguard Worker        v.requires_grad_()
335*da0073e9SAndroid Build Coastguard Worker        res = autogradF.vjp(reducer, inputs, v, create_graph=True)
336*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
337*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0].grad_fn)
338*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1].grad_fn)
339*da0073e9SAndroid Build Coastguard Worker
340*da0073e9SAndroid Build Coastguard Worker        gradcheck(
341*da0073e9SAndroid Build Coastguard Worker            lambda inp, v: autogradF.vjp(reducer, inputs, v, create_graph=True),
342*da0073e9SAndroid Build Coastguard Worker            (inputs, v),
343*da0073e9SAndroid Build Coastguard Worker        )
344*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
345*da0073e9SAndroid Build Coastguard Worker            lambda inp, v: autogradF.vjp(reducer, inputs, v, create_graph=True),
346*da0073e9SAndroid Build Coastguard Worker            (inputs, v),
347*da0073e9SAndroid Build Coastguard Worker        )
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
350*da0073e9SAndroid Build Coastguard Worker            return 2 * x + 3 * y, x * y
351*da0073e9SAndroid Build Coastguard Worker
352*da0073e9SAndroid Build Coastguard Worker        inputs = (
353*da0073e9SAndroid Build Coastguard Worker            ctors.rand(2, dtype=torch.double, requires_grad=True),
354*da0073e9SAndroid Build Coastguard Worker            ctors.rand(2, dtype=torch.double, requires_grad=True),
355*da0073e9SAndroid Build Coastguard Worker        )
356*da0073e9SAndroid Build Coastguard Worker        v = (
357*da0073e9SAndroid Build Coastguard Worker            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
358*da0073e9SAndroid Build Coastguard Worker            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
359*da0073e9SAndroid Build Coastguard Worker        )
360*da0073e9SAndroid Build Coastguard Worker
361*da0073e9SAndroid Build Coastguard Worker        gradcheck(
362*da0073e9SAndroid Build Coastguard Worker            lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[
363*da0073e9SAndroid Build Coastguard Worker                1
364*da0073e9SAndroid Build Coastguard Worker            ],
365*da0073e9SAndroid Build Coastguard Worker            inputs + v,
366*da0073e9SAndroid Build Coastguard Worker        )
367*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
368*da0073e9SAndroid Build Coastguard Worker            lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[
369*da0073e9SAndroid Build Coastguard Worker                1
370*da0073e9SAndroid Build Coastguard Worker            ],
371*da0073e9SAndroid Build Coastguard Worker            inputs + v,
372*da0073e9SAndroid Build Coastguard Worker        )
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Worker        def foo(*args):
375*da0073e9SAndroid Build Coastguard Worker            x, y = args[:2]
376*da0073e9SAndroid Build Coastguard Worker            v = args[2:]
377*da0073e9SAndroid Build Coastguard Worker
378*da0073e9SAndroid Build Coastguard Worker            x = x.cos()
379*da0073e9SAndroid Build Coastguard Worker            val, grad = autogradF.vjp(adder, (x, y), v, create_graph=True)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker            return (
382*da0073e9SAndroid Build Coastguard Worker                val[0].exp()
383*da0073e9SAndroid Build Coastguard Worker                + val[1].exp()
384*da0073e9SAndroid Build Coastguard Worker                + grad[0].exp()
385*da0073e9SAndroid Build Coastguard Worker                + grad[1].exp()
386*da0073e9SAndroid Build Coastguard Worker                + x.exp()
387*da0073e9SAndroid Build Coastguard Worker                + y.exp()
388*da0073e9SAndroid Build Coastguard Worker            )
389*da0073e9SAndroid Build Coastguard Worker
390*da0073e9SAndroid Build Coastguard Worker        gradcheck(foo, inputs + v)
391*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(foo, inputs + v)
392*da0073e9SAndroid Build Coastguard Worker
393*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
394*da0073e9SAndroid Build Coastguard Worker    def test_jvp_err_check(self, ctors):
395*da0073e9SAndroid Build Coastguard Worker        def foo(a):
396*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3)
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker        def bar(a):
399*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3), "bar"
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
402*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand(4)
403*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
404*da0073e9SAndroid Build Coastguard Worker            TypeError, "The inputs given to jvp must be either a Tensor"
405*da0073e9SAndroid Build Coastguard Worker        ):
406*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(foo, (inp, 2), v)
407*da0073e9SAndroid Build Coastguard Worker
408*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
409*da0073e9SAndroid Build Coastguard Worker            TypeError, "The outputs of the user-provided function given to jvp must"
410*da0073e9SAndroid Build Coastguard Worker        ):
411*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(bar, inp, v)
412*da0073e9SAndroid Build Coastguard Worker
413*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
414*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
415*da0073e9SAndroid Build Coastguard Worker            "The vector v can only be None if the input to the user-provided function",
416*da0073e9SAndroid Build Coastguard Worker        ):
417*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(foo, inp)
418*da0073e9SAndroid Build Coastguard Worker
419*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
420*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "The given v should contain a single Tensor."
421*da0073e9SAndroid Build Coastguard Worker        ):
422*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(foo, inp, (v, v))
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
425*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "v has invalid size: should be torch.Size"
426*da0073e9SAndroid Build Coastguard Worker        ):
427*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(foo, inp, v[:2])
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(foo, inp, v)[1]
430*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res, foo(inp))
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
433*da0073e9SAndroid Build Coastguard Worker    def test_jvp_err_check_strict(self, ctors):
434*da0073e9SAndroid Build Coastguard Worker        def foo(a):
435*da0073e9SAndroid Build Coastguard Worker            return a.detach()
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker        def bar(a):
438*da0073e9SAndroid Build Coastguard Worker            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
439*da0073e9SAndroid Build Coastguard Worker            return a.long().float().requires_grad_().clone()
440*da0073e9SAndroid Build Coastguard Worker
441*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
442*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand(4)
443*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
444*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
445*da0073e9SAndroid Build Coastguard Worker            "Output 0 of the user-provided function does not require gradients.",
446*da0073e9SAndroid Build Coastguard Worker        ):
447*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(foo, inp, v, strict=True)
448*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(foo, inp, v, strict=False)
449*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
450*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res[1].abs().sum(), 0.0)
451*da0073e9SAndroid Build Coastguard Worker
452*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
453*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
454*da0073e9SAndroid Build Coastguard Worker            "The output of the user-provided function is independent of input 0",
455*da0073e9SAndroid Build Coastguard Worker        ):
456*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(bar, inp, v, strict=True)
457*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(bar, inp, v, strict=False)
458*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
459*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res[1].abs().sum(), 0.0)
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker        # The Jacobian does not depend on the input
462*da0073e9SAndroid Build Coastguard Worker        def foo(a):
463*da0073e9SAndroid Build Coastguard Worker            return a.clone()
464*da0073e9SAndroid Build Coastguard Worker
465*da0073e9SAndroid Build Coastguard Worker        inp.requires_grad_()
466*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
467*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
468*da0073e9SAndroid Build Coastguard Worker            "jacobian of the user-provided function is independent of input 0.",
469*da0073e9SAndroid Build Coastguard Worker        ):
470*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(foo, inp, v, create_graph=True, strict=True)
471*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(foo, inp, v, create_graph=True, strict=False)
472*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inp)
473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res[1], v)
474*da0073e9SAndroid Build Coastguard Worker
475*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
476*da0073e9SAndroid Build Coastguard Worker    def test_jvp_no_grad(self, ctors):
477*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
478*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=1)
479*da0073e9SAndroid Build Coastguard Worker
480*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
481*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4, 4)
482*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
483*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(reducer, inputs, v)
484*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0].grad_fn)
485*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1].grad_fn)
486*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res[1], ctors.zeros(4, 4))
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Worker        inputs.requires_grad_()
489*da0073e9SAndroid Build Coastguard Worker        v.requires_grad_()
490*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
491*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jvp(reducer, inputs, v, create_graph=True)
492*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0].grad_fn)
493*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1].grad_fn)
494*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res[1], ctors.zeros(4, 4))
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
497*da0073e9SAndroid Build Coastguard Worker    def test_jvp_output(self, ctors):
498*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
499*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=1)
500*da0073e9SAndroid Build Coastguard Worker
501*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
502*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4, 4)
503*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(reducer, inputs, v)
504*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
505*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0].grad_fn)
506*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1].grad_fn)
507*da0073e9SAndroid Build Coastguard Worker
508*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
509*da0073e9SAndroid Build Coastguard Worker            return 2 * x + 3 * y
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker        inputs = (ctors.rand(2), ctors.rand(2))
512*da0073e9SAndroid Build Coastguard Worker        v = (ctors.ones(2), ctors.ones(2))
513*da0073e9SAndroid Build Coastguard Worker        out, jvp_val = autogradF.jvp(adder, inputs, v)
514*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(jvp_val, out)
515*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out.grad_fn)
516*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(jvp_val[0].grad_fn)
517*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(jvp_val[1].grad_fn)
518*da0073e9SAndroid Build Coastguard Worker
519*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
520*da0073e9SAndroid Build Coastguard Worker            return 2 * x + 3 * y, x + y
521*da0073e9SAndroid Build Coastguard Worker
522*da0073e9SAndroid Build Coastguard Worker        inputs = (ctors.rand(2), ctors.rand(2))
523*da0073e9SAndroid Build Coastguard Worker        v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0]))
524*da0073e9SAndroid Build Coastguard Worker        out, jvp_val = autogradF.jvp(adder, inputs, v)
525*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(jvp_val, out)
526*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out[0].grad_fn)
527*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(out[1].grad_fn)
528*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(jvp_val[0].grad_fn)
529*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(jvp_val[1].grad_fn)
530*da0073e9SAndroid Build Coastguard Worker
531*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
532*da0073e9SAndroid Build Coastguard Worker    def test_jvp_scalar(self, ctors):
533*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
534*da0073e9SAndroid Build Coastguard Worker            return x.sum()
535*da0073e9SAndroid Build Coastguard Worker
536*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
537*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4, 4)
538*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(reducer, inputs, v)
539*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[0], ctors.zeros([]))
540*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
541*da0073e9SAndroid Build Coastguard Worker
542*da0073e9SAndroid Build Coastguard Worker        def expander(x):
543*da0073e9SAndroid Build Coastguard Worker            return x.unsqueeze(0).repeat(4)
544*da0073e9SAndroid Build Coastguard Worker
545*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand([])
546*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones([])
547*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(expander, inputs, v)
548*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[0], ctors.zeros(4))
549*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
550*da0073e9SAndroid Build Coastguard Worker
551*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(expander, inputs)
552*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[0], ctors.zeros(4))
553*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
554*da0073e9SAndroid Build Coastguard Worker
555*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
556*da0073e9SAndroid Build Coastguard Worker    def test_jvp_create_graph(self, ctors):
557*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
558*da0073e9SAndroid Build Coastguard Worker            return x.sum(dim=1)
559*da0073e9SAndroid Build Coastguard Worker
560*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(2, 2, dtype=torch.double)
561*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(2, 2, dtype=torch.double)
562*da0073e9SAndroid Build Coastguard Worker
563*da0073e9SAndroid Build Coastguard Worker        inputs.requires_grad_()
564*da0073e9SAndroid Build Coastguard Worker        v.requires_grad_()
565*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jvp(reducer, inputs, v, create_graph=True)
566*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], res[0])
567*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0].grad_fn)
568*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1].grad_fn)
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker        gradcheck(
571*da0073e9SAndroid Build Coastguard Worker            lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True),
572*da0073e9SAndroid Build Coastguard Worker            (inputs, v),
573*da0073e9SAndroid Build Coastguard Worker        )
574*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
575*da0073e9SAndroid Build Coastguard Worker            lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True),
576*da0073e9SAndroid Build Coastguard Worker            (inputs, v),
577*da0073e9SAndroid Build Coastguard Worker        )
578*da0073e9SAndroid Build Coastguard Worker
579*da0073e9SAndroid Build Coastguard Worker        def adder(x, y):
580*da0073e9SAndroid Build Coastguard Worker            return 2 * x + 3 * y, x * y
581*da0073e9SAndroid Build Coastguard Worker
582*da0073e9SAndroid Build Coastguard Worker        inputs = (
583*da0073e9SAndroid Build Coastguard Worker            ctors.rand(2, dtype=torch.double, requires_grad=True),
584*da0073e9SAndroid Build Coastguard Worker            ctors.rand(2, dtype=torch.double, requires_grad=True),
585*da0073e9SAndroid Build Coastguard Worker        )
586*da0073e9SAndroid Build Coastguard Worker        v = (
587*da0073e9SAndroid Build Coastguard Worker            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
588*da0073e9SAndroid Build Coastguard Worker            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
589*da0073e9SAndroid Build Coastguard Worker        )
590*da0073e9SAndroid Build Coastguard Worker
591*da0073e9SAndroid Build Coastguard Worker        gradcheck(
592*da0073e9SAndroid Build Coastguard Worker            lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[
593*da0073e9SAndroid Build Coastguard Worker                1
594*da0073e9SAndroid Build Coastguard Worker            ],
595*da0073e9SAndroid Build Coastguard Worker            inputs + v,
596*da0073e9SAndroid Build Coastguard Worker        )
597*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
598*da0073e9SAndroid Build Coastguard Worker            lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[
599*da0073e9SAndroid Build Coastguard Worker                1
600*da0073e9SAndroid Build Coastguard Worker            ],
601*da0073e9SAndroid Build Coastguard Worker            inputs + v,
602*da0073e9SAndroid Build Coastguard Worker        )
603*da0073e9SAndroid Build Coastguard Worker
604*da0073e9SAndroid Build Coastguard Worker        def foo(*args):
605*da0073e9SAndroid Build Coastguard Worker            x, y = args[:2]
606*da0073e9SAndroid Build Coastguard Worker            v = args[2:]
607*da0073e9SAndroid Build Coastguard Worker
608*da0073e9SAndroid Build Coastguard Worker            x = x.cos()
609*da0073e9SAndroid Build Coastguard Worker            val, grad = autogradF.jvp(adder, (x, y), v, create_graph=True)
610*da0073e9SAndroid Build Coastguard Worker
611*da0073e9SAndroid Build Coastguard Worker            return (
612*da0073e9SAndroid Build Coastguard Worker                val[0].exp()
613*da0073e9SAndroid Build Coastguard Worker                + val[1].exp()
614*da0073e9SAndroid Build Coastguard Worker                + grad[0].exp()
615*da0073e9SAndroid Build Coastguard Worker                + grad[1].exp()
616*da0073e9SAndroid Build Coastguard Worker                + x.exp()
617*da0073e9SAndroid Build Coastguard Worker                + y.exp()
618*da0073e9SAndroid Build Coastguard Worker            )
619*da0073e9SAndroid Build Coastguard Worker
620*da0073e9SAndroid Build Coastguard Worker        gradcheck(foo, inputs + v)
621*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(foo, inputs + v)
622*da0073e9SAndroid Build Coastguard Worker
623*da0073e9SAndroid Build Coastguard Worker    def _test_construct_standard_basis_for(self, inputs):
624*da0073e9SAndroid Build Coastguard Worker        numels = tuple(tensor.numel() for tensor in inputs)
625*da0073e9SAndroid Build Coastguard Worker        results = autogradF._construct_standard_basis_for(inputs, numels)
626*da0073e9SAndroid Build Coastguard Worker        for result, inp in zip(results, inputs):
627*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.dtype, inp.dtype)
628*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result.device, inp.device)
629*da0073e9SAndroid Build Coastguard Worker        results = torch.cat(
630*da0073e9SAndroid Build Coastguard Worker            [result.to(device="cpu", dtype=torch.float) for result in results], dim=1
631*da0073e9SAndroid Build Coastguard Worker        )
632*da0073e9SAndroid Build Coastguard Worker        expected = torch.eye(results[0].shape[0], dtype=torch.float)
633*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(results, expected)
634*da0073e9SAndroid Build Coastguard Worker
635*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
636*da0073e9SAndroid Build Coastguard Worker    def test_construct_standard_basis_for(self, ctors):
637*da0073e9SAndroid Build Coastguard Worker        test_cases = [
638*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(2, 3),),
639*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(1),),
640*da0073e9SAndroid Build Coastguard Worker            (ctors.randn([]),),
641*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(1), ctors.randn([]), ctors.randn([])),
642*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(2), ctors.randn(3), ctors.randn([])),
643*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(2), ctors.randn([]), ctors.randn(3)),
644*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(2, 3), ctors.randn(3), ctors.randn(3, 4, 2)),
645*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(2, dtype=torch.float64), ctors.randn(3, dtype=torch.float32)),
646*da0073e9SAndroid Build Coastguard Worker        ]
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker        for inputs in test_cases:
649*da0073e9SAndroid Build Coastguard Worker            self._test_construct_standard_basis_for(inputs)
650*da0073e9SAndroid Build Coastguard Worker
651*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
652*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
653*da0073e9SAndroid Build Coastguard Worker    def test_construct_standard_basis_for_cuda(self, ctors):
654*da0073e9SAndroid Build Coastguard Worker        test_cases = [
655*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(2), ctors.randn(3, device="cuda")),
656*da0073e9SAndroid Build Coastguard Worker            (ctors.randn(3, device="cuda"), ctors.randn(2)),
657*da0073e9SAndroid Build Coastguard Worker        ]
658*da0073e9SAndroid Build Coastguard Worker
659*da0073e9SAndroid Build Coastguard Worker        for inputs in test_cases:
660*da0073e9SAndroid Build Coastguard Worker            self._test_construct_standard_basis_for(inputs)
661*da0073e9SAndroid Build Coastguard Worker
662*da0073e9SAndroid Build Coastguard Worker    def _test_vectorize_raises_no_warnings(self, api, ctors):
663*da0073e9SAndroid Build Coastguard Worker        # vmap is an experimental prototype. When someone calls torch.vmap,
664*da0073e9SAndroid Build Coastguard Worker        # it raises a python warning. This test checks that
665*da0073e9SAndroid Build Coastguard Worker        # autogradF.{jacobian, hessian} don't raise that experimental prototype
666*da0073e9SAndroid Build Coastguard Worker        # warning; it is not nice for a public-facing API to raise a warning
667*da0073e9SAndroid Build Coastguard Worker        # no matter how it is called.
668*da0073e9SAndroid Build Coastguard Worker        def foo(a):
669*da0073e9SAndroid Build Coastguard Worker            return (a**2).sum()
670*da0073e9SAndroid Build Coastguard Worker
671*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(3)
672*da0073e9SAndroid Build Coastguard Worker        with warnings.catch_warnings(record=True) as wa:
673*da0073e9SAndroid Build Coastguard Worker            result = api(foo, x, vectorize=True)
674*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(len(wa), 0)
675*da0073e9SAndroid Build Coastguard Worker
676*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
677*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_raises_no_warnings(self, ctors):
678*da0073e9SAndroid Build Coastguard Worker        return self._test_vectorize_raises_no_warnings(autogradF.jacobian, ctors)
679*da0073e9SAndroid Build Coastguard Worker
680*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
681*da0073e9SAndroid Build Coastguard Worker    def test_hessian_vectorize_raises_no_warnings(self, ctors):
682*da0073e9SAndroid Build Coastguard Worker        return self._test_vectorize_raises_no_warnings(autogradF.hessian, ctors)
683*da0073e9SAndroid Build Coastguard Worker
684*da0073e9SAndroid Build Coastguard Worker    @parametrize("vectorize", [True, False])
685*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
686*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_err_check(self, vectorize, ctors):
687*da0073e9SAndroid Build Coastguard Worker        def foo(a):
688*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3)
689*da0073e9SAndroid Build Coastguard Worker
690*da0073e9SAndroid Build Coastguard Worker        def bar(a):
691*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3), "bar"
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
694*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
695*da0073e9SAndroid Build Coastguard Worker            TypeError, "The inputs given to jacobian must be either a Tensor"
696*da0073e9SAndroid Build Coastguard Worker        ):
697*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(foo, (inp, 2), vectorize=vectorize)
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
700*da0073e9SAndroid Build Coastguard Worker            TypeError,
701*da0073e9SAndroid Build Coastguard Worker            "The outputs of the user-provided function given to jacobian must",
702*da0073e9SAndroid Build Coastguard Worker        ):
703*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(bar, inp, vectorize=vectorize)
704*da0073e9SAndroid Build Coastguard Worker
705*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(foo, inp, vectorize=vectorize)
706*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, foo(inp), inp)
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
709*da0073e9SAndroid Build Coastguard Worker            return b, 3 * a.narrow(0, 0, 3)
710*da0073e9SAndroid Build Coastguard Worker
711*da0073e9SAndroid Build Coastguard Worker        inp = (ctors.rand(4), ctors.rand(5))
712*da0073e9SAndroid Build Coastguard Worker
713*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(foo, inp, vectorize=vectorize)
714*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, foo(*inp), inp)
715*da0073e9SAndroid Build Coastguard Worker
716*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
717*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_err_check_strict(self, ctors):
718*da0073e9SAndroid Build Coastguard Worker        def foo(a):
719*da0073e9SAndroid Build Coastguard Worker            return a.detach()
720*da0073e9SAndroid Build Coastguard Worker
721*da0073e9SAndroid Build Coastguard Worker        def bar(a):
722*da0073e9SAndroid Build Coastguard Worker            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
723*da0073e9SAndroid Build Coastguard Worker            return a.long().float().requires_grad_().clone()
724*da0073e9SAndroid Build Coastguard Worker
725*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
726*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
727*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
728*da0073e9SAndroid Build Coastguard Worker            "Output 0 of the user-provided function does not require gradients.",
729*da0073e9SAndroid Build Coastguard Worker        ):
730*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(foo, inp, strict=True)
731*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(foo, inp, strict=False)
732*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, foo(inp), inp)
733*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.abs().sum(), 0.0)
734*da0073e9SAndroid Build Coastguard Worker
735*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
736*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
737*da0073e9SAndroid Build Coastguard Worker            "Output 0 of the user-provided function is independent of input 0.",
738*da0073e9SAndroid Build Coastguard Worker        ):
739*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(bar, inp, strict=True)
740*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(bar, inp, strict=False)
741*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, foo(inp), inp)
742*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.abs().sum(), 0.0)
743*da0073e9SAndroid Build Coastguard Worker
744*da0073e9SAndroid Build Coastguard Worker        # The Jacobian does not depend on the input
745*da0073e9SAndroid Build Coastguard Worker        def foo(a):
746*da0073e9SAndroid Build Coastguard Worker            return a.clone()
747*da0073e9SAndroid Build Coastguard Worker
748*da0073e9SAndroid Build Coastguard Worker        inp.requires_grad_()
749*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
750*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
751*da0073e9SAndroid Build Coastguard Worker            "jacobian of the user-provided function is independent of input 0.",
752*da0073e9SAndroid Build Coastguard Worker        ):
753*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(foo, inp, create_graph=True, strict=True)
754*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(foo, inp, create_graph=True, strict=False)
755*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inp, inp)
756*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, torch.eye(4))
757*da0073e9SAndroid Build Coastguard Worker
758*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
759*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_err_check_strict_vectorize(self, ctors):
760*da0073e9SAndroid Build Coastguard Worker        def foo(x):
761*da0073e9SAndroid Build Coastguard Worker            return x
762*da0073e9SAndroid Build Coastguard Worker
763*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
764*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "not supported together"):
765*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(foo, inp, strict=True, vectorize=True)
766*da0073e9SAndroid Build Coastguard Worker
767*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
768*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_no_grad(self, ctors):
769*da0073e9SAndroid Build Coastguard Worker        def exp_reducer(x):
770*da0073e9SAndroid Build Coastguard Worker            return x.exp().sum(dim=1)
771*da0073e9SAndroid Build Coastguard Worker
772*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
773*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
774*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(exp_reducer, inputs)
775*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res.grad_fn)
776*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res, ctors.zeros(4, 4))
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
779*da0073e9SAndroid Build Coastguard Worker            res = autogradF.jacobian(exp_reducer, inputs, create_graph=True)
780*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res.grad_fn)
781*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res, ctors.zeros(4, 4))
782*da0073e9SAndroid Build Coastguard Worker
783*da0073e9SAndroid Build Coastguard Worker    @vectorized_logging_tensor
784*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_output(self, vectorize, ctors):
785*da0073e9SAndroid Build Coastguard Worker        def exp_reducer(x):
786*da0073e9SAndroid Build Coastguard Worker            return x.exp().sum(dim=1)
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
789*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(exp_reducer, inputs, vectorize=vectorize)
790*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
791*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res.grad_fn)
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker        def identity(x):
794*da0073e9SAndroid Build Coastguard Worker            return x.clone()
795*da0073e9SAndroid Build Coastguard Worker
796*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4)
797*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(identity, inputs, vectorize=vectorize)
798*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, identity(inputs), inputs)
799*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res.grad_fn)
800*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, torch.eye(4))
801*da0073e9SAndroid Build Coastguard Worker
802*da0073e9SAndroid Build Coastguard Worker        def add_exp_reducer(x, y):
803*da0073e9SAndroid Build Coastguard Worker            return (x + y.exp()).sum(dim=1)
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Worker        inputs = (ctors.rand(4, 4), ctors.rand(4, 4))
806*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(add_exp_reducer, inputs, vectorize=vectorize)
807*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
808*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0].grad_fn)
809*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1].grad_fn)
810*da0073e9SAndroid Build Coastguard Worker
811*da0073e9SAndroid Build Coastguard Worker    @vectorized_logging_tensor
812*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_scalar(self, vectorize, ctors):
813*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
814*da0073e9SAndroid Build Coastguard Worker            return x.sum()
815*da0073e9SAndroid Build Coastguard Worker
816*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
817*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(reducer, inputs, vectorize=vectorize)
818*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res, inputs)
819*da0073e9SAndroid Build Coastguard Worker
820*da0073e9SAndroid Build Coastguard Worker        def expander(x):
821*da0073e9SAndroid Build Coastguard Worker            return x.unsqueeze(0).repeat(4)
822*da0073e9SAndroid Build Coastguard Worker
823*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand([])
824*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(expander, inputs, vectorize=vectorize)
825*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res, ctors.zeros(4))
826*da0073e9SAndroid Build Coastguard Worker
827*da0073e9SAndroid Build Coastguard Worker    @parametrize("vectorize", [True, False])
828*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
829*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_create_graph(self, vectorize, ctors):
830*da0073e9SAndroid Build Coastguard Worker        def exp_reducer(x):
831*da0073e9SAndroid Build Coastguard Worker            return x.exp().sum(dim=1)
832*da0073e9SAndroid Build Coastguard Worker
833*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
834*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(
835*da0073e9SAndroid Build Coastguard Worker            exp_reducer, inputs, create_graph=True, vectorize=vectorize
836*da0073e9SAndroid Build Coastguard Worker        )
837*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
838*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res.grad_fn)
839*da0073e9SAndroid Build Coastguard Worker
840*da0073e9SAndroid Build Coastguard Worker        gradcheck(
841*da0073e9SAndroid Build Coastguard Worker            lambda inp: autogradF.jacobian(
842*da0073e9SAndroid Build Coastguard Worker                exp_reducer, inp, create_graph=True, vectorize=vectorize
843*da0073e9SAndroid Build Coastguard Worker            ),
844*da0073e9SAndroid Build Coastguard Worker            inputs,
845*da0073e9SAndroid Build Coastguard Worker        )
846*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
847*da0073e9SAndroid Build Coastguard Worker            lambda inp: autogradF.jacobian(
848*da0073e9SAndroid Build Coastguard Worker                exp_reducer, inp, create_graph=True, vectorize=vectorize
849*da0073e9SAndroid Build Coastguard Worker            ),
850*da0073e9SAndroid Build Coastguard Worker            inputs,
851*da0073e9SAndroid Build Coastguard Worker        )
852*da0073e9SAndroid Build Coastguard Worker
853*da0073e9SAndroid Build Coastguard Worker        def add_exp_reducer(x, y):
854*da0073e9SAndroid Build Coastguard Worker            return (x + y).exp().sum(dim=1)
855*da0073e9SAndroid Build Coastguard Worker
856*da0073e9SAndroid Build Coastguard Worker        inputs = (
857*da0073e9SAndroid Build Coastguard Worker            ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
858*da0073e9SAndroid Build Coastguard Worker            ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
859*da0073e9SAndroid Build Coastguard Worker        )
860*da0073e9SAndroid Build Coastguard Worker        res = autogradF.jacobian(
861*da0073e9SAndroid Build Coastguard Worker            add_exp_reducer, inputs, create_graph=True, vectorize=vectorize
862*da0073e9SAndroid Build Coastguard Worker        )
863*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
864*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0].grad_fn)
865*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1].grad_fn)
866*da0073e9SAndroid Build Coastguard Worker
867*da0073e9SAndroid Build Coastguard Worker        gradcheck(
868*da0073e9SAndroid Build Coastguard Worker            lambda *inp: autogradF.jacobian(
869*da0073e9SAndroid Build Coastguard Worker                add_exp_reducer, inp, create_graph=True, vectorize=vectorize
870*da0073e9SAndroid Build Coastguard Worker            ),
871*da0073e9SAndroid Build Coastguard Worker            inputs,
872*da0073e9SAndroid Build Coastguard Worker        )
873*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
874*da0073e9SAndroid Build Coastguard Worker            lambda *inp: autogradF.jacobian(
875*da0073e9SAndroid Build Coastguard Worker                add_exp_reducer, inp, create_graph=True, vectorize=vectorize
876*da0073e9SAndroid Build Coastguard Worker            ),
877*da0073e9SAndroid Build Coastguard Worker            inputs,
878*da0073e9SAndroid Build Coastguard Worker        )
879*da0073e9SAndroid Build Coastguard Worker
880*da0073e9SAndroid Build Coastguard Worker        def foo(x, y):
881*da0073e9SAndroid Build Coastguard Worker            x = x.cos()
882*da0073e9SAndroid Build Coastguard Worker            val, jac = autogradF.jacobian(
883*da0073e9SAndroid Build Coastguard Worker                add_exp_reducer, (x, y), create_graph=True, vectorize=vectorize
884*da0073e9SAndroid Build Coastguard Worker            )
885*da0073e9SAndroid Build Coastguard Worker
886*da0073e9SAndroid Build Coastguard Worker            res = val[0].exp().sum() + val[1].exp().sum() + jac[0].exp().sum()
887*da0073e9SAndroid Build Coastguard Worker            res = res + jac[1].exp().sum() + x.exp().sum() + y.exp().sum()
888*da0073e9SAndroid Build Coastguard Worker            return res
889*da0073e9SAndroid Build Coastguard Worker
890*da0073e9SAndroid Build Coastguard Worker        gradcheck(foo, inputs)
891*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(foo, inputs)
892*da0073e9SAndroid Build Coastguard Worker
893*da0073e9SAndroid Build Coastguard Worker    def _check_jacobian_vectorize_correctness(self, f, inputs, test_forward_ad=True):
894*da0073e9SAndroid Build Coastguard Worker        expected = autogradF.jacobian(f, inputs, vectorize=False)
895*da0073e9SAndroid Build Coastguard Worker        result_backward_mode = autogradF.jacobian(f, inputs, vectorize=True)
896*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_backward_mode, expected)
897*da0073e9SAndroid Build Coastguard Worker
898*da0073e9SAndroid Build Coastguard Worker        if test_forward_ad:
899*da0073e9SAndroid Build Coastguard Worker            result_forward_mode = autogradF.jacobian(
900*da0073e9SAndroid Build Coastguard Worker                f, inputs, strategy="forward-mode", vectorize=True
901*da0073e9SAndroid Build Coastguard Worker            )
902*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(result_forward_mode, expected)
903*da0073e9SAndroid Build Coastguard Worker
904*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
905*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_simple(self, ctors):
906*da0073e9SAndroid Build Coastguard Worker        def f(x):
907*da0073e9SAndroid Build Coastguard Worker            return 3 * x**2
908*da0073e9SAndroid Build Coastguard Worker
909*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2, 3, 5)
910*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, x)
911*da0073e9SAndroid Build Coastguard Worker
912*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
913*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_multi_input(self, ctors):
914*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
915*da0073e9SAndroid Build Coastguard Worker            return (x.cos() * x) @ y.sin()
916*da0073e9SAndroid Build Coastguard Worker
917*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2, 3)
918*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3, 5)
919*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, (x, y))
920*da0073e9SAndroid Build Coastguard Worker
921*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
922*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_multi_input_multi_output(self, ctors):
923*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
924*da0073e9SAndroid Build Coastguard Worker            return (x * x) @ y, x @ (x.sum(1) * y), y.sum()
925*da0073e9SAndroid Build Coastguard Worker
926*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(5, 3)
927*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3, 5)
928*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, (x, y))
929*da0073e9SAndroid Build Coastguard Worker
930*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
931*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_unrelated_outputs(self, ctors):
932*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
933*da0073e9SAndroid Build Coastguard Worker            return x, y, x, y
934*da0073e9SAndroid Build Coastguard Worker
935*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2)
936*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3)
937*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, (x, y))
938*da0073e9SAndroid Build Coastguard Worker
939*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
940*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_zero_dim(self, ctors):
941*da0073e9SAndroid Build Coastguard Worker        # zero-dim output
942*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
943*da0073e9SAndroid Build Coastguard Worker            return x.sum(), y.sum(), x * y
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(3)
946*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3)
947*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, (x, y))
948*da0073e9SAndroid Build Coastguard Worker
949*da0073e9SAndroid Build Coastguard Worker        # zero-dim input
950*da0073e9SAndroid Build Coastguard Worker        def g(x):
951*da0073e9SAndroid Build Coastguard Worker            return torch.stack([x, x, x])
952*da0073e9SAndroid Build Coastguard Worker
953*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn([])
954*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(g, x)
955*da0073e9SAndroid Build Coastguard Worker
956*da0073e9SAndroid Build Coastguard Worker        # Mixed zero-dim input / zero-dim output
957*da0073e9SAndroid Build Coastguard Worker        def h(x, y):
958*da0073e9SAndroid Build Coastguard Worker            return y.sum(), x * y
959*da0073e9SAndroid Build Coastguard Worker
960*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn([])
961*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(1)
962*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(h, (x, y))
963*da0073e9SAndroid Build Coastguard Worker
964*da0073e9SAndroid Build Coastguard Worker    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
965*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
966*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_different_devices(self, ctors):
967*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
968*da0073e9SAndroid Build Coastguard Worker            return x * y, (x * y).cuda()
969*da0073e9SAndroid Build Coastguard Worker
970*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(3)
971*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3)
972*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, (x, y))
973*da0073e9SAndroid Build Coastguard Worker
974*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
975*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_vectorize_correctness_different_dtype(self, ctors):
976*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
977*da0073e9SAndroid Build Coastguard Worker            return (x * y).float(), (x * y).double()
978*da0073e9SAndroid Build Coastguard Worker
979*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(3)
980*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3)
981*da0073e9SAndroid Build Coastguard Worker        # The Jacobian computed using forward AD has the dtype of the output
982*da0073e9SAndroid Build Coastguard Worker        # but the Jacobian computed with reverse AD has dtype of input
983*da0073e9SAndroid Build Coastguard Worker        self._check_jacobian_vectorize_correctness(f, (x, y), test_forward_ad=False)
984*da0073e9SAndroid Build Coastguard Worker
985*da0073e9SAndroid Build Coastguard Worker    def _check_hessian_vectorize_correctness(self, f, inputs):
986*da0073e9SAndroid Build Coastguard Worker        expected = autogradF.hessian(f, inputs, vectorize=False)
987*da0073e9SAndroid Build Coastguard Worker        result = autogradF.hessian(f, inputs, vectorize=True)
988*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result, expected)
989*da0073e9SAndroid Build Coastguard Worker
990*da0073e9SAndroid Build Coastguard Worker        result_forward_mode = autogradF.hessian(
991*da0073e9SAndroid Build Coastguard Worker            f, inputs, outer_jacobian_strategy="forward-mode", vectorize=True
992*da0073e9SAndroid Build Coastguard Worker        )
993*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(result_forward_mode, expected)
994*da0073e9SAndroid Build Coastguard Worker
995*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
996*da0073e9SAndroid Build Coastguard Worker    def test_hessian_vectorize_correctness_simple(self, ctors):
997*da0073e9SAndroid Build Coastguard Worker        def f(x):
998*da0073e9SAndroid Build Coastguard Worker            return (3 * x**2).sum()
999*da0073e9SAndroid Build Coastguard Worker
1000*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2, 3, 5)
1001*da0073e9SAndroid Build Coastguard Worker        self._check_hessian_vectorize_correctness(f, x)
1002*da0073e9SAndroid Build Coastguard Worker
1003*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1004*da0073e9SAndroid Build Coastguard Worker    def test_hessian_vectorize_correctness_multi_input(self, ctors):
1005*da0073e9SAndroid Build Coastguard Worker        def f(x, y, z):
1006*da0073e9SAndroid Build Coastguard Worker            return ((x.relu() * x) @ y.sin() @ z).sum()
1007*da0073e9SAndroid Build Coastguard Worker
1008*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2, 3)
1009*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3, 5)
1010*da0073e9SAndroid Build Coastguard Worker        z = ctors.randn(5, 5)
1011*da0073e9SAndroid Build Coastguard Worker        self._check_hessian_vectorize_correctness(f, (x, y, z))
1012*da0073e9SAndroid Build Coastguard Worker
1013*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1014*da0073e9SAndroid Build Coastguard Worker    def test_hessian_vectorize_correctness_unrelated_outputs(self, ctors):
1015*da0073e9SAndroid Build Coastguard Worker        # output unrelated to one input
1016*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
1017*da0073e9SAndroid Build Coastguard Worker            return (x**2).sum()
1018*da0073e9SAndroid Build Coastguard Worker
1019*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2)
1020*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3)
1021*da0073e9SAndroid Build Coastguard Worker        self._check_hessian_vectorize_correctness(f, (x, y))
1022*da0073e9SAndroid Build Coastguard Worker
1023*da0073e9SAndroid Build Coastguard Worker        # output unrelated to all inputs
1024*da0073e9SAndroid Build Coastguard Worker        def f(x, y):
1025*da0073e9SAndroid Build Coastguard Worker            return ctors.ones([])
1026*da0073e9SAndroid Build Coastguard Worker
1027*da0073e9SAndroid Build Coastguard Worker        x = ctors.randn(2)
1028*da0073e9SAndroid Build Coastguard Worker        y = ctors.randn(3)
1029*da0073e9SAndroid Build Coastguard Worker        self._check_hessian_vectorize_correctness(f, (x, y))
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker    @parametrize("vectorize", [True, False])
1032*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1033*da0073e9SAndroid Build Coastguard Worker    def test_hessian_err_check(self, vectorize, ctors):
1034*da0073e9SAndroid Build Coastguard Worker        def foo(a):
1035*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3).exp().sum()
1036*da0073e9SAndroid Build Coastguard Worker
1037*da0073e9SAndroid Build Coastguard Worker        def bar(a):
1038*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3), "bar"
1039*da0073e9SAndroid Build Coastguard Worker
1040*da0073e9SAndroid Build Coastguard Worker        def bar2(a):
1041*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3)
1042*da0073e9SAndroid Build Coastguard Worker
1043*da0073e9SAndroid Build Coastguard Worker        def bar3(a):
1044*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3), 3 * a.narrow(0, 0, 3)
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
1047*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1048*da0073e9SAndroid Build Coastguard Worker            TypeError, "The inputs given to hessian must be either a Tensor"
1049*da0073e9SAndroid Build Coastguard Worker        ):
1050*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(foo, (inp, 2), vectorize=vectorize)
1051*da0073e9SAndroid Build Coastguard Worker
1052*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1053*da0073e9SAndroid Build Coastguard Worker            TypeError, "The outputs of the user-provided function given to hessian must"
1054*da0073e9SAndroid Build Coastguard Worker        ):
1055*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(bar, inp, vectorize=vectorize)
1056*da0073e9SAndroid Build Coastguard Worker
1057*da0073e9SAndroid Build Coastguard Worker        err_msg_out = "The Tensor returned by the function given to hessian should contain a single element"
1058*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, err_msg_out):
1059*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(bar2, inp, vectorize=vectorize)
1060*da0073e9SAndroid Build Coastguard Worker
1061*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1062*da0073e9SAndroid Build Coastguard Worker            RuntimeError, "The function given to hessian should return a single Tensor"
1063*da0073e9SAndroid Build Coastguard Worker        ):
1064*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(bar3, inp, vectorize=vectorize)
1065*da0073e9SAndroid Build Coastguard Worker
1066*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(foo, inp, vectorize=vectorize)
1067*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inp, inp)
1068*da0073e9SAndroid Build Coastguard Worker
1069*da0073e9SAndroid Build Coastguard Worker        def foo(a, b):
1070*da0073e9SAndroid Build Coastguard Worker            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()
1071*da0073e9SAndroid Build Coastguard Worker
1072*da0073e9SAndroid Build Coastguard Worker        inp = (ctors.rand(4), ctors.rand(5))
1073*da0073e9SAndroid Build Coastguard Worker
1074*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(foo, inp, vectorize=vectorize)
1075*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inp, inp)
1076*da0073e9SAndroid Build Coastguard Worker
1077*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1078*da0073e9SAndroid Build Coastguard Worker    def test_hessian_err_check_strict(self, ctors):
1079*da0073e9SAndroid Build Coastguard Worker        def foo(a):
1080*da0073e9SAndroid Build Coastguard Worker            return a.detach().sum()
1081*da0073e9SAndroid Build Coastguard Worker
1082*da0073e9SAndroid Build Coastguard Worker        def bar(a):
1083*da0073e9SAndroid Build Coastguard Worker            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
1084*da0073e9SAndroid Build Coastguard Worker            return a.long().float().requires_grad_().clone().sum()
1085*da0073e9SAndroid Build Coastguard Worker
1086*da0073e9SAndroid Build Coastguard Worker        def bar2(a):
1087*da0073e9SAndroid Build Coastguard Worker            # A Linear function for which the jacobian is independent of the input
1088*da0073e9SAndroid Build Coastguard Worker            return (3 * a).sum()
1089*da0073e9SAndroid Build Coastguard Worker
1090*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
1091*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1092*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1093*da0073e9SAndroid Build Coastguard Worker            "Output 0 of the user-provided function does not require gradients.",
1094*da0073e9SAndroid Build Coastguard Worker        ):
1095*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(foo, inp, strict=True)
1096*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(foo, inp, strict=False)
1097*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inp, inp)
1098*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.abs().sum(), 0.0)
1099*da0073e9SAndroid Build Coastguard Worker
1100*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1101*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1102*da0073e9SAndroid Build Coastguard Worker            "jacobian of the user-provided function with respect to input 0",
1103*da0073e9SAndroid Build Coastguard Worker        ):
1104*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(bar, inp, strict=True)
1105*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(bar, inp, strict=False)
1106*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inp, inp)
1107*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.abs().sum(), 0.0)
1108*da0073e9SAndroid Build Coastguard Worker
1109*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(
1110*da0073e9SAndroid Build Coastguard Worker            RuntimeError,
1111*da0073e9SAndroid Build Coastguard Worker            "jacobian of the user-provided function with respect to input 0 is",
1112*da0073e9SAndroid Build Coastguard Worker        ):
1113*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(bar2, inp, strict=True)
1114*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(bar2, inp, strict=False)
1115*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inp, inp)
1116*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res.abs().sum(), 0.0)
1117*da0073e9SAndroid Build Coastguard Worker
1118*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1119*da0073e9SAndroid Build Coastguard Worker    def test_hessian_err_check_strict_vectorize(self, ctors):
1120*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1121*da0073e9SAndroid Build Coastguard Worker            return (x**3).sum()
1122*da0073e9SAndroid Build Coastguard Worker
1123*da0073e9SAndroid Build Coastguard Worker        inp = ctors.rand(4)
1124*da0073e9SAndroid Build Coastguard Worker        with self.assertRaisesRegex(RuntimeError, "not supported together"):
1125*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(foo, inp, strict=True, vectorize=True)
1126*da0073e9SAndroid Build Coastguard Worker
1127*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1128*da0073e9SAndroid Build Coastguard Worker    def test_hessian_no_grad(self, ctors):
1129*da0073e9SAndroid Build Coastguard Worker        def pow_reducer(x):
1130*da0073e9SAndroid Build Coastguard Worker            return x.pow(3).sum()
1131*da0073e9SAndroid Build Coastguard Worker
1132*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(2, 2)
1133*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1134*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(pow_reducer, inputs)
1135*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0][0].grad_fn)
1136*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0][1].grad_fn)
1137*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1][0].grad_fn)
1138*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1][1].grad_fn)
1139*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res, ctors.zeros(2, 2, 2))
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker        with torch.no_grad():
1142*da0073e9SAndroid Build Coastguard Worker            res = autogradF.hessian(pow_reducer, inputs, create_graph=True)
1143*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0][0].grad_fn)
1144*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0][1].grad_fn)
1145*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1][0].grad_fn)
1146*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1][1].grad_fn)
1147*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(res, ctors.zeros(2, 2, 2))
1148*da0073e9SAndroid Build Coastguard Worker
1149*da0073e9SAndroid Build Coastguard Worker    @vectorized_logging_tensor
1150*da0073e9SAndroid Build Coastguard Worker    def test_hessian_output(self, vectorize, ctors):
1151*da0073e9SAndroid Build Coastguard Worker        def pow_reducer(x):
1152*da0073e9SAndroid Build Coastguard Worker            return x.pow(3).sum()
1153*da0073e9SAndroid Build Coastguard Worker
1154*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(2, 2)
1155*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(pow_reducer, inputs, vectorize=vectorize)
1156*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inputs, inputs)
1157*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res.grad_fn)
1158*da0073e9SAndroid Build Coastguard Worker
1159*da0073e9SAndroid Build Coastguard Worker        def add_pow_reducer(x, y):
1160*da0073e9SAndroid Build Coastguard Worker            return (x + y).pow(3).sum()
1161*da0073e9SAndroid Build Coastguard Worker
1162*da0073e9SAndroid Build Coastguard Worker        inputs = (ctors.rand(2, 2), ctors.rand(2, 2))
1163*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(add_pow_reducer, inputs, vectorize=vectorize)
1164*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inputs, inputs)
1165*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0][0].grad_fn)
1166*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[0][1].grad_fn)
1167*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1][0].grad_fn)
1168*da0073e9SAndroid Build Coastguard Worker        self.assertIsNone(res[1][1].grad_fn)
1169*da0073e9SAndroid Build Coastguard Worker
1170*da0073e9SAndroid Build Coastguard Worker    @parametrize("vectorize", [True, False])
1171*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1172*da0073e9SAndroid Build Coastguard Worker    def test_hessian_scalar(self, vectorize, ctors):
1173*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
1174*da0073e9SAndroid Build Coastguard Worker            return x.sum()
1175*da0073e9SAndroid Build Coastguard Worker
1176*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
1177*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(reducer, inputs, vectorize=vectorize)
1178*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inputs, inputs)
1179*da0073e9SAndroid Build Coastguard Worker
1180*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand([])
1181*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(reducer, inputs, vectorize=vectorize)
1182*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res, inputs)
1183*da0073e9SAndroid Build Coastguard Worker
1184*da0073e9SAndroid Build Coastguard Worker        def bad_reducer(x):
1185*da0073e9SAndroid Build Coastguard Worker            return x.sum().view(1, 1, 1)
1186*da0073e9SAndroid Build Coastguard Worker
1187*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
1188*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hessian(bad_reducer, inputs, vectorize=vectorize)
1189*da0073e9SAndroid Build Coastguard Worker        self._assert_interleaved_struct(res, inputs, inputs)
1190*da0073e9SAndroid Build Coastguard Worker
@parametrize("vectorize", [True, False])
@base_and_logging_tensor
def test_hessian_create_graph(self, vectorize, ctors):
    # Exercises autogradF.hessian with create_graph=True for a single Tensor
    # input and for a tuple of two Tensors. In both cases the result must keep
    # a grad_fn and pass gradcheck / gradgradcheck; the tuple case is also
    # checked when the hessian call is embedded in a larger function.
    def hessian_with_graph(fn, args):
        # Every hessian in this test is computed with the same keyword setup.
        return autogradF.hessian(
            fn, args, create_graph=True, vectorize=vectorize
        )

    # ---- single Tensor input ----
    def cube_sum(x):
        return x.pow(3).sum()

    def cube_sum_hessian(inp):
        return hessian_with_graph(cube_sum, inp)

    single = ctors.rand(2, 2, dtype=torch.double, requires_grad=True)
    single_hess = hessian_with_graph(cube_sum, single)
    self._assert_interleaved_struct(single_hess, single, single)
    self.assertIsNotNone(single_hess.grad_fn)

    gradcheck(cube_sum_hessian, single)
    gradgradcheck(cube_sum_hessian, single)

    # ---- tuple of two Tensor inputs ----
    def add_cube_sum(x, y):
        return (x + y).pow(3).sum()

    pair = (
        ctors.rand(2, 2, dtype=torch.double, requires_grad=True),
        ctors.rand(2, 2, dtype=torch.double, requires_grad=True),
    )
    pair_hess = hessian_with_graph(add_cube_sum, pair)
    self._assert_interleaved_struct(pair_hess, pair, pair)
    # Each of the four blocks of the nested result must be differentiable.
    for hess_row in pair_hess:
        for hess_block in hess_row:
            self.assertIsNotNone(hess_block.grad_fn)

    def flatten(nested):
        # Turn the tuple-of-tuples result into one flat tuple of Tensors.
        flat = []
        for row in nested:
            flat.extend(row)
        return tuple(flat)

    def add_cube_sum_hessian_flat(*args):
        return flatten(hessian_with_graph(add_cube_sum, args))

    gradcheck(add_cube_sum_hessian_flat, pair)
    gradgradcheck(add_cube_sum_hessian_flat, pair)

    # ---- hessian call nested inside a larger function ----
    def composite(x, y):
        cos_x = x.cos()
        # With a two-Tensor input the result unpacks into its two rows.
        first_row, second_row = hessian_with_graph(add_cube_sum, (cos_x, y))

        # Accumulate strictly left to right over the four hessian blocks,
        # then the transformed x, then y.
        terms = (*first_row, *second_row, cos_x, y)
        total = terms[0].cos().sum()
        for term in terms[1:]:
            total = total + term.cos().sum()
        return total

    gradcheck(composite, pair)
    gradgradcheck(composite, pair)
1265*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_vhp_err_check(self, ctors):
    # Checks the errors raised by autogradF.vhp for malformed inputs, outputs
    # and v, then confirms that well-formed calls return a vhp value with the
    # same structure as the inputs.
    def scalar_fn(a):
        return 3 * a.narrow(0, 0, 3).exp().sum()

    def tensor_and_str_fn(a):
        # Second output is not a Tensor.
        return 3 * a.narrow(0, 0, 3), "bar"

    def multi_element_fn(a):
        # Output holds three elements rather than one.
        return 3 * a.narrow(0, 0, 3)

    inp = ctors.rand(4)
    v = ctors.rand(4)

    bad_inputs_msg = "The inputs given to vhp must be either a Tensor"
    with self.assertRaisesRegex(TypeError, bad_inputs_msg):
        autogradF.vhp(scalar_fn, (inp, 2), v)

    bad_outputs_msg = "The outputs of the user-provided function given to vhp must"
    with self.assertRaisesRegex(TypeError, bad_outputs_msg):
        autogradF.vhp(tensor_and_str_fn, inp, v)

    err_msg_out = "The Tensor returned by the function given to vhp should contain a single element"
    with self.assertRaisesRegex(RuntimeError, err_msg_out):
        autogradF.vhp(multi_element_fn, inp, v)

    with self.assertRaisesRegex(RuntimeError, "v has invalid size:"):
        autogradF.vhp(scalar_fn, inp, ctors.rand(5))

    bad_v_msg = "The v given to vhp must be either a Tensor or a tuple of Tensors"
    with self.assertRaisesRegex(TypeError, bad_v_msg):
        autogradF.vhp(scalar_fn, inp, (v, 2))

    # Well-formed single-Tensor call.
    _, vhp_val = autogradF.vhp(scalar_fn, inp, v)
    self._assert_same_struct(vhp_val, inp)

    # Well-formed call with a tuple of inputs and a matching tuple of v.
    def product_fn(a, b):
        return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()

    pair_inp = (ctors.rand(4), ctors.rand(5))
    pair_v = (ctors.rand(4), ctors.rand(5))

    _, pair_vhp = autogradF.vhp(product_fn, pair_inp, pair_v)
    self._assert_same_struct(pair_vhp, pair_inp)
1313*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_vhp_err_check_strict(self, ctors):
    # For three functions, autogradF.vhp must raise a RuntimeError when
    # strict=True, and with strict=False must return a vhp value shaped like
    # the input whose absolute sum is zero.
    def detached_fn(a):
        return a.detach().sum()

    def disconnected_fn(a):
        # Make a non-leaf Tensor that requires_grad but that is not connected to the input
        return a.long().float().requires_grad_().clone().sum()

    def linear_fn(a):
        # A Linear function for which the jacobian is independent of the input
        return (3 * a).sum()

    inp = ctors.rand(4)
    v = ctors.rand(4)

    # (function, message expected under strict=True), checked in this order.
    strict_cases = (
        (
            detached_fn,
            "Output 0 of the user-provided function does not require gradients.",
        ),
        (
            disconnected_fn,
            "The output of the user-provided function is independent of input 0",
        ),
        (
            linear_fn,
            "jacobian of the user-provided function with respect to input 0 is",
        ),
    )
    for fn, strict_msg in strict_cases:
        with self.assertRaisesRegex(RuntimeError, strict_msg):
            autogradF.vhp(fn, inp, v, strict=True)

        _, vhp_val = autogradF.vhp(fn, inp, v, strict=False)
        self._assert_same_struct(vhp_val, inp)
        self.assertEqual(vhp_val.abs().sum(), 0.0)
1355*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_vhp_no_grad(self, ctors):
    # Calls autogradF.vhp inside torch.no_grad(): the vhp value must still be
    # non-zero, and the results carry a grad_fn only when create_graph=True.
    def exp_sum(x):
        return x.exp().sum()

    inputs = ctors.rand(4, 4)
    v = ctors.ones(4, 4)

    # Default create_graph: neither result is attached to a graph.
    with torch.no_grad():
        out, vhp_val = autogradF.vhp(exp_sum, inputs, v)
    for result in (out, vhp_val):
        self.assertIsNone(result.grad_fn)
    self.assertNotEqual(vhp_val, ctors.zeros(4, 4))

    # create_graph=True: both results are attached to a graph.
    with torch.no_grad():
        out, vhp_val = autogradF.vhp(exp_sum, inputs, v, create_graph=True)
    for result in (out, vhp_val):
        self.assertIsNotNone(result.grad_fn)
    self.assertNotEqual(vhp_val, ctors.zeros(4, 4))
1374*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_vhp_output(self, ctors):
    # Without create_graph, autogradF.vhp must return a vhp value structured
    # like the inputs, and none of the returned Tensors may have a grad_fn.
    def narrow_exp_sum(a):
        return 3 * a.narrow(0, 0, 3).exp().sum()

    # Single Tensor input.
    single = ctors.rand(4, 4)
    single_v = ctors.ones(4, 4)
    out, vhp_val = autogradF.vhp(narrow_exp_sum, single, single_v)
    self._assert_same_struct(vhp_val, single)
    self.assertIsNone(out.grad_fn)
    self.assertIsNone(vhp_val.grad_fn)

    # Tuple of two Tensor inputs.
    def pair_exp_sum(a, b):
        return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

    pair = (ctors.rand(3), ctors.rand(4))
    pair_v = (ctors.ones(3), ctors.ones(4))
    out, vhp_val = autogradF.vhp(pair_exp_sum, pair, pair_v)
    self._assert_same_struct(vhp_val, pair)
    self.assertIsNone(out.grad_fn)
    for vhp_part in (vhp_val[0], vhp_val[1]):
        self.assertIsNone(vhp_part.grad_fn)
1397*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_vhp_scalar(self, ctors):
    # The vhp value returned by autogradF.vhp must mirror the structure of
    # the inputs for a 4x4 input, a zero-dimensional input (with and without
    # an explicit v), and a function whose single-element output is 3-D.
    def total(x):
        return x.sum()

    matrix = ctors.rand(4, 4)
    ones = ctors.ones(4, 4)
    _, vhp_val = autogradF.vhp(total, matrix, ones)
    self._assert_same_struct(vhp_val, matrix)

    # Zero-dimensional input with an explicit zero-dimensional v.
    scalar_inp = ctors.rand([])
    scalar_v = ctors.rand([])
    _, vhp_val = autogradF.vhp(total, scalar_inp, scalar_v)
    self._assert_same_struct(vhp_val, scalar_inp)

    # Same zero-dimensional input, v left out.
    _, vhp_val = autogradF.vhp(total, scalar_inp)
    self._assert_same_struct(vhp_val, scalar_inp)

    def total_as_3d(x):
        # Single element, but reshaped to (1, 1, 1).
        return x.sum().view(1, 1, 1)

    matrix = ctors.rand(4, 4)
    rand_v = ctors.rand(4, 4)
    _, vhp_val = autogradF.vhp(total_as_3d, matrix, rand_v)
    self._assert_same_struct(vhp_val, matrix)
1423*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_vhp_create_graph(self, ctors):
    # Exercises autogradF.vhp with create_graph=True for a single Tensor
    # input and for a tuple of two Tensors. The results must keep a grad_fn
    # and pass gradcheck / gradgradcheck; the tuple case is also checked when
    # the vhp call is embedded in a larger function.

    # ---- single Tensor input ----
    def narrow_exp_sum(a):
        return 3 * a.narrow(0, 0, 3).exp().sum()

    def single_vhp(inp, vec):
        # Returns the full (output, vhp) pair.
        return autogradF.vhp(narrow_exp_sum, inp, vec, create_graph=True)

    single = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
    single_v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True)
    out, vhp_val = autogradF.vhp(narrow_exp_sum, single, single_v, create_graph=True)
    self._assert_same_struct(vhp_val, single)
    self.assertIsNotNone(out.grad_fn)
    self.assertIsNotNone(vhp_val.grad_fn)

    gradcheck(single_vhp, (single, single_v))
    gradgradcheck(single_vhp, (single, single_v))

    # ---- tuple of two Tensor inputs ----
    def bar(a, b):
        return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

    pair = (
        ctors.rand(3, dtype=torch.double, requires_grad=True),
        ctors.rand(4, dtype=torch.double, requires_grad=True),
    )
    pair_v = (
        ctors.ones(3, dtype=torch.double, requires_grad=True),
        ctors.ones(4, dtype=torch.double, requires_grad=True),
    )
    out, vhp_val = autogradF.vhp(bar, pair, pair_v, create_graph=True)
    self._assert_same_struct(vhp_val, pair)
    self.assertIsNotNone(out.grad_fn)
    for vhp_part in (vhp_val[0], vhp_val[1]):
        self.assertIsNotNone(vhp_part.grad_fn)

    def pair_vhp_only(*args):
        # First two positional args are the inputs, the rest is v; only the
        # vhp half of the result is returned.
        return autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1]

    gradcheck(pair_vhp_only, pair + pair_v)
    gradgradcheck(pair_vhp_only, pair + pair_v)

    # ---- vhp call nested inside a larger function ----
    def composite(*args):
        x, y = args[:2]
        vec = args[2:]

        cos_x = x.cos()
        val, grad = autogradF.vhp(bar, (cos_x, y), vec, create_graph=True)

        # Accumulate strictly left to right.
        total = val.cos() + grad[0].cos().sum()
        total = total + grad[1].cos()
        total = total + cos_x.cos().sum()
        return total + y.cos()

    gradcheck(composite, pair + pair_v)
    gradgradcheck(composite, pair + pair_v)
1486*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_hvp_err_check(self, ctors):
    # Checks the errors raised by autogradF.hvp for malformed inputs, outputs
    # and v, then confirms that well-formed calls return an hvp value with the
    # same structure as the inputs.
    def scalar_fn(a):
        return 3 * a.narrow(0, 0, 3).exp().sum()

    def tensor_and_str_fn(a):
        # Second output is not a Tensor.
        return 3 * a.narrow(0, 0, 3), "bar"

    def multi_element_fn(a):
        # Output holds three elements rather than one.
        return 3 * a.narrow(0, 0, 3)

    inp = ctors.rand(4)
    v = ctors.rand(4)
    # Valid call made ahead of the error cases; its result is not inspected.
    autogradF.hvp(scalar_fn, inp, v)

    bad_inputs_msg = "The inputs given to hvp must be either a Tensor"
    with self.assertRaisesRegex(TypeError, bad_inputs_msg):
        autogradF.hvp(scalar_fn, (inp, 2), v)

    bad_outputs_msg = "The outputs of the user-provided function given to hvp must"
    with self.assertRaisesRegex(TypeError, bad_outputs_msg):
        autogradF.hvp(tensor_and_str_fn, inp, v)

    err_msg_out = "The Tensor returned by the function given to hvp should contain a single element"
    with self.assertRaisesRegex(RuntimeError, err_msg_out):
        autogradF.hvp(multi_element_fn, inp, v)

    with self.assertRaisesRegex(RuntimeError, "v has invalid size:"):
        autogradF.hvp(scalar_fn, inp, ctors.rand(5))

    bad_v_msg = "The v given to hvp must be either a Tensor or a tuple of Tensors"
    with self.assertRaisesRegex(TypeError, bad_v_msg):
        autogradF.hvp(scalar_fn, inp, (v, 2))

    # Well-formed single-Tensor call.
    _, hvp_val = autogradF.hvp(scalar_fn, inp, v)
    self._assert_same_struct(hvp_val, inp)

    # Well-formed call with a tuple of inputs and a matching tuple of v.
    def product_fn(a, b):
        return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()

    pair_inp = (ctors.rand(4), ctors.rand(5))
    pair_v = (ctors.rand(4), ctors.rand(5))

    _, pair_hvp = autogradF.hvp(product_fn, pair_inp, pair_v)
    self._assert_same_struct(pair_hvp, pair_inp)
1535*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_hvp_err_check_strict(self, ctors):
    # For three functions, autogradF.hvp must raise a RuntimeError when
    # strict=True, and with strict=False must return an hvp value shaped like
    # the input whose absolute sum is zero.
    def detached_fn(a):
        return a.detach().sum()

    def disconnected_fn(a):
        # Make a non-leaf Tensor that requires_grad but that is not connected to the input
        return a.long().float().requires_grad_().clone().sum()

    def linear_fn(a):
        # A Linear function for which the jacobian is independent of the input
        return (3 * a).sum()

    inp = ctors.rand(4)
    v = ctors.rand(4)

    # (function, message expected under strict=True), checked in this order.
    strict_cases = (
        (
            detached_fn,
            "Output 0 of the user-provided function does not require gradients.",
        ),
        (
            disconnected_fn,
            "The output of the user-provided function is independent of input 0",
        ),
        (
            linear_fn,
            "jacobian of the user-provided function with respect to input 0 is",
        ),
    )
    for fn, strict_msg in strict_cases:
        with self.assertRaisesRegex(RuntimeError, strict_msg):
            autogradF.hvp(fn, inp, v, strict=True)

        _, hvp_val = autogradF.hvp(fn, inp, v, strict=False)
        self._assert_same_struct(hvp_val, inp)
        self.assertEqual(hvp_val.abs().sum(), 0.0)
1577*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_hvp_no_grad(self, ctors):
    # Calls autogradF.hvp inside torch.no_grad(): the hvp value must still be
    # non-zero, and the results carry a grad_fn only when create_graph=True.
    def exp_sum(x):
        return x.exp().sum()

    inputs = ctors.rand(4, 4)
    v = ctors.ones(4, 4)

    # Default create_graph: neither result is attached to a graph.
    with torch.no_grad():
        out, hvp_val = autogradF.hvp(exp_sum, inputs, v)
    for result in (out, hvp_val):
        self.assertIsNone(result.grad_fn)
    self.assertNotEqual(hvp_val, ctors.zeros(4, 4))

    # create_graph=True: both results are attached to a graph.
    with torch.no_grad():
        out, hvp_val = autogradF.hvp(exp_sum, inputs, v, create_graph=True)
    for result in (out, hvp_val):
        self.assertIsNotNone(result.grad_fn)
    self.assertNotEqual(hvp_val, ctors.zeros(4, 4))
1596*da0073e9SAndroid Build Coastguard Worker
@base_and_logging_tensor
def test_hvp_output(self, ctors):
    # Without create_graph, autogradF.hvp must return an hvp value structured
    # like the inputs, and none of the returned Tensors may have a grad_fn.
    def narrow_exp_sum(a):
        return 3 * a.narrow(0, 0, 3).exp().sum()

    # Single Tensor input.
    single = ctors.rand(4, 4)
    single_v = ctors.ones(4, 4)
    out, hvp_val = autogradF.hvp(narrow_exp_sum, single, single_v)
    self._assert_same_struct(hvp_val, single)
    self.assertIsNone(out.grad_fn)
    self.assertIsNone(hvp_val.grad_fn)

    # Tuple of two Tensor inputs.
    def pair_exp_sum(a, b):
        return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

    pair = (ctors.rand(3), ctors.rand(4))
    pair_v = (ctors.ones(3), ctors.ones(4))
    out, hvp_val = autogradF.hvp(pair_exp_sum, pair, pair_v)
    self._assert_same_struct(hvp_val, pair)
    self.assertIsNone(out.grad_fn)
    for hvp_part in (hvp_val[0], hvp_val[1]):
        self.assertIsNone(hvp_part.grad_fn)
1619*da0073e9SAndroid Build Coastguard Worker
1620*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1621*da0073e9SAndroid Build Coastguard Worker    def test_hvp_scalar(self, ctors):
1622*da0073e9SAndroid Build Coastguard Worker        def reducer(x):
1623*da0073e9SAndroid Build Coastguard Worker            return x.exp().sum()
1624*da0073e9SAndroid Build Coastguard Worker
1625*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
1626*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4, 4)
1627*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hvp(reducer, inputs, v)
1628*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
1629*da0073e9SAndroid Build Coastguard Worker
1630*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand([])
1631*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand([])
1632*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hvp(reducer, inputs, v)
1633*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
1634*da0073e9SAndroid Build Coastguard Worker
1635*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hvp(reducer, inputs)
1636*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
1637*da0073e9SAndroid Build Coastguard Worker
1638*da0073e9SAndroid Build Coastguard Worker        def bad_reducer(x):
1639*da0073e9SAndroid Build Coastguard Worker            return x.exp().sum().view(1, 1, 1)
1640*da0073e9SAndroid Build Coastguard Worker
1641*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4)
1642*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand(4, 4)
1643*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hvp(bad_reducer, inputs, v)
1644*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
1645*da0073e9SAndroid Build Coastguard Worker
1646*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1647*da0073e9SAndroid Build Coastguard Worker    def test_hvp_create_graph(self, ctors):
1648*da0073e9SAndroid Build Coastguard Worker        def foo(a):
1649*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3).exp().sum()
1650*da0073e9SAndroid Build Coastguard Worker
1651*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
1652*da0073e9SAndroid Build Coastguard Worker        v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True)
1653*da0073e9SAndroid Build Coastguard Worker        res = autogradF.hvp(foo, inputs, v, create_graph=True)
1654*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(res[1], inputs)
1655*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[0].grad_fn)
1656*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(res[1].grad_fn)
1657*da0073e9SAndroid Build Coastguard Worker
1658*da0073e9SAndroid Build Coastguard Worker        gradcheck(
1659*da0073e9SAndroid Build Coastguard Worker            lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v)
1660*da0073e9SAndroid Build Coastguard Worker        )
1661*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
1662*da0073e9SAndroid Build Coastguard Worker            lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v)
1663*da0073e9SAndroid Build Coastguard Worker        )
1664*da0073e9SAndroid Build Coastguard Worker
1665*da0073e9SAndroid Build Coastguard Worker        def bar(a, b):
1666*da0073e9SAndroid Build Coastguard Worker            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
1667*da0073e9SAndroid Build Coastguard Worker
1668*da0073e9SAndroid Build Coastguard Worker        inputs = (
1669*da0073e9SAndroid Build Coastguard Worker            ctors.rand(3, dtype=torch.double, requires_grad=True),
1670*da0073e9SAndroid Build Coastguard Worker            ctors.rand(4, dtype=torch.double, requires_grad=True),
1671*da0073e9SAndroid Build Coastguard Worker        )
1672*da0073e9SAndroid Build Coastguard Worker        v = (
1673*da0073e9SAndroid Build Coastguard Worker            ctors.ones(3, dtype=torch.double, requires_grad=True),
1674*da0073e9SAndroid Build Coastguard Worker            ctors.ones(4, dtype=torch.double, requires_grad=True),
1675*da0073e9SAndroid Build Coastguard Worker        )
1676*da0073e9SAndroid Build Coastguard Worker        out, hvp_val = autogradF.hvp(bar, inputs, v, create_graph=True)
1677*da0073e9SAndroid Build Coastguard Worker        self._assert_same_struct(hvp_val, inputs)
1678*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(out.grad_fn)
1679*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(hvp_val[0].grad_fn)
1680*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(hvp_val[1].grad_fn)
1681*da0073e9SAndroid Build Coastguard Worker
1682*da0073e9SAndroid Build Coastguard Worker        gradcheck(
1683*da0073e9SAndroid Build Coastguard Worker            lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1],
1684*da0073e9SAndroid Build Coastguard Worker            inputs + v,
1685*da0073e9SAndroid Build Coastguard Worker        )
1686*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(
1687*da0073e9SAndroid Build Coastguard Worker            lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1],
1688*da0073e9SAndroid Build Coastguard Worker            inputs + v,
1689*da0073e9SAndroid Build Coastguard Worker        )
1690*da0073e9SAndroid Build Coastguard Worker
1691*da0073e9SAndroid Build Coastguard Worker        def foo(*args):
1692*da0073e9SAndroid Build Coastguard Worker            x, y = args[:2]
1693*da0073e9SAndroid Build Coastguard Worker            v = args[2:]
1694*da0073e9SAndroid Build Coastguard Worker
1695*da0073e9SAndroid Build Coastguard Worker            x = x.cos()
1696*da0073e9SAndroid Build Coastguard Worker            val, grad = autogradF.hvp(bar, (x, y), v, create_graph=True)
1697*da0073e9SAndroid Build Coastguard Worker
1698*da0073e9SAndroid Build Coastguard Worker            return (
1699*da0073e9SAndroid Build Coastguard Worker                val.cos()
1700*da0073e9SAndroid Build Coastguard Worker                + grad[0].cos().sum()
1701*da0073e9SAndroid Build Coastguard Worker                + grad[1].cos()
1702*da0073e9SAndroid Build Coastguard Worker                + x.cos().sum()
1703*da0073e9SAndroid Build Coastguard Worker                + y.cos()
1704*da0073e9SAndroid Build Coastguard Worker            )
1705*da0073e9SAndroid Build Coastguard Worker
1706*da0073e9SAndroid Build Coastguard Worker        gradcheck(foo, inputs + v)
1707*da0073e9SAndroid Build Coastguard Worker        gradgradcheck(foo, inputs + v)
1708*da0073e9SAndroid Build Coastguard Worker
1709*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1710*da0073e9SAndroid Build Coastguard Worker    def test_jacobian_match_vjp_jvp(self, ctors):
1711*da0073e9SAndroid Build Coastguard Worker        def foo(x):
1712*da0073e9SAndroid Build Coastguard Worker            return x**3 + x.sum()
1713*da0073e9SAndroid Build Coastguard Worker
1714*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4)
1715*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand(4)
1716*da0073e9SAndroid Build Coastguard Worker
1717*da0073e9SAndroid Build Coastguard Worker        jac = autogradF.jacobian(foo, inputs)
1718*da0073e9SAndroid Build Coastguard Worker        jvp = autogradF.jvp(foo, inputs, v)[1]
1719*da0073e9SAndroid Build Coastguard Worker        vjp = autogradF.vjp(foo, inputs, v)[1]
1720*da0073e9SAndroid Build Coastguard Worker
1721*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(jvp, torch.mm(jac, v.unsqueeze(1)).squeeze(1))
1722*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vjp, torch.mm(v.unsqueeze(0), jac).squeeze(0))
1723*da0073e9SAndroid Build Coastguard Worker
1724*da0073e9SAndroid Build Coastguard Worker    @base_and_logging_tensor
1725*da0073e9SAndroid Build Coastguard Worker    def test_hessian_match_vhp_hvp(self, ctors):
1726*da0073e9SAndroid Build Coastguard Worker        def foo(a):
1727*da0073e9SAndroid Build Coastguard Worker            return 3 * a.narrow(0, 0, 3).exp().sum()
1728*da0073e9SAndroid Build Coastguard Worker
1729*da0073e9SAndroid Build Coastguard Worker        inputs = ctors.rand(4)
1730*da0073e9SAndroid Build Coastguard Worker        v = ctors.rand(4)
1731*da0073e9SAndroid Build Coastguard Worker
1732*da0073e9SAndroid Build Coastguard Worker        hes = autogradF.hessian(foo, inputs)
1733*da0073e9SAndroid Build Coastguard Worker        hvp = autogradF.hvp(foo, inputs, v)[1]
1734*da0073e9SAndroid Build Coastguard Worker        vhp = autogradF.vhp(foo, inputs, v)[1]
1735*da0073e9SAndroid Build Coastguard Worker
1736*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
1737*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))
1738*da0073e9SAndroid Build Coastguard Worker
1739*da0073e9SAndroid Build Coastguard Worker
# Instantiate the parametrized tests declared on TestAutogradFunctional.
instantiate_parametrized_tests(TestAutogradFunctional)

# Run the suite only when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
1744