# Owner(s): ["module: autograd"]

import types
import unittest
import warnings

import torch
import torch.autograd.functional as autogradF
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    gradcheck,
    gradgradcheck,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)
from torch.testing._internal.logging_tensor import LoggingTensor


# Utilities for parametrizing the tensor constructors used in autograd tests
#
# TODO: maybe move somewhere so other tests can also use
#
# NB: Not all factory functions included. A complete(?) list can be found here:
#     https://pytorch.org/cppdocs/notes/tensor_creation.html
base_ctors_dict = {
    "ones": torch.ones,
    "zeros": torch.zeros,
    "randn": torch.randn,
    "rand": torch.rand,
    "tensor": torch.tensor,
}
base_ctors = types.SimpleNamespace(**base_ctors_dict)


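# Note: wrap a tensor factory so that it produces LoggingTensors (a wrapper
# Tensor subclass that records dispatched ops). requires_grad is popped from
# the kwargs and passed to the LoggingTensor wrapper itself, so the wrapper
# (rather than the inner plain tensor) is the autograd leaf.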
def wrap_with_logging_tensor(ctor):
    def wrapper(*args, **kwargs):
        requires_grad = kwargs.pop("requires_grad", False)
        return LoggingTensor(ctor(*args, **kwargs), requires_grad=requires_grad)

    return wrapper


logging_tensor_ctors_dict = {
    k: wrap_with_logging_tensor(ctor) for (k, ctor) in base_ctors_dict.items()
}
logging_tensor_ctors = types.SimpleNamespace(**logging_tensor_ctors_dict)

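# Each decorator below parametrizes a test over the listed subtests: the
# decorated test runs once per subtest, receiving the corresponding factory
# namespace as `ctors` (the plain torch factories in base_ctors, or the
# LoggingTensor-wrapping factories defined above).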
base_and_logging_tensor = parametrize(
    "ctors",
    [
        subtest(base_ctors, name="base_tensor"),
        subtest(logging_tensor_ctors, name="logging_tensor"),
    ],
)

FIXME_base_and_xfail_logging_tensor = parametrize(
    "ctors",
    [
        subtest(base_ctors, name="base_tensor"),
        subtest(
            logging_tensor_ctors,
            name="logging_tensor",
            decorators=[unittest.expectedFailure],
        ),
    ],
)

# NB: This is equivalent to having both @parametrize("vectorized", [True, False]) and
#     FIXME_base_and_xfail_logging_tensor, except the non-vectorized logging_tensor case is
#     actually expected to succeed
FIXME_xfail_vectorized_logging_tensor = parametrize(
    "vectorize,ctors",
    [
        subtest((True, base_ctors), name="vectorized_base_tensor"),
        subtest((False, base_ctors), name="base_tensor"),
        subtest(
            (True, logging_tensor_ctors),
            name="vectorized_logging_tensor",
            decorators=[unittest.expectedFailure],
        ),
        subtest((False, logging_tensor_ctors), name="logging_tensor"),
    ],
)

vectorized_logging_tensor = parametrize(
    "vectorize,ctors",
    [
        subtest((True, base_ctors), name="vectorized_base_tensor"),
        subtest((False, base_ctors), name="base_tensor"),
        subtest((True, logging_tensor_ctors), name="vectorized_logging_tensor"),
        subtest((False, logging_tensor_ctors), name="logging_tensor"),
    ],
)


class TestAutogradFunctional(TestCase):
    def _assert_same_struct(self, res, base):
        # base and res should be Tensors or tuple of Tensors with the same size
        if isinstance(base, torch.Tensor):
            self.assertTrue(isinstance(res, torch.Tensor))
            self.assertEqual(base.size(), res.size())
        elif isinstance(base, tuple):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(base), len(res))
            for el_base, el_res in zip(base, res):
                self.assertTrue(isinstance(el_base, torch.Tensor))
                self.assertTrue(isinstance(el_res, torch.Tensor))
                self.assertEqual(el_base.size(), el_res.size())
        else:
            # Wrong base
            raise RuntimeError(
                "The base given to `_assert_same_struct` doesn't have"
                " the right structure."
            )

    def _assert_interleaved_struct(self, res, base1, base2):
        # base1 and base2 can be Tensors or tuples of Tensors.
        # If they are tuples, res should be a tuple as well.
        # The indexing works as follows for base1, base2 being
        # - tuple, tuple: res[i][j][k][l] = (base1[i][k], base2[j][l])
        # - tuple, Tensor: res[i][k][l] = (base1[i][k], base2[l])
        # - Tensor, tuple: res[i][j][l] = (base1[i], base2[j][l])
        # - Tensor, Tensor: res[k][l] = (base1[k], base2[l])
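        # Example (illustrative shapes): for base1 = (tensor of size (2,),
        # tensor of size (3,)) and base2 a tensor of size (4, 5), res should
        # be a 2-tuple whose elements have sizes (2, 4, 5) and (3, 4, 5).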
        if isinstance(base1, torch.Tensor) and isinstance(base2, torch.Tensor):
            self.assertTrue(isinstance(res, torch.Tensor))
            self.assertEqual(res.size(), base1.size() + base2.size())
        elif isinstance(base1, tuple) and isinstance(base2, torch.Tensor):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(res), len(base1))
            for el_res, el_base1 in zip(res, base1):
                self.assertTrue(isinstance(el_res, torch.Tensor))
                self.assertTrue(isinstance(el_base1, torch.Tensor))
                self.assertEqual(el_res.size(), el_base1.size() + base2.size())
        elif isinstance(base1, torch.Tensor) and isinstance(base2, tuple):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(res), len(base2))
            for el_res, el_base2 in zip(res, base2):
                self.assertTrue(isinstance(el_res, torch.Tensor))
                self.assertTrue(isinstance(el_base2, torch.Tensor))
                self.assertEqual(el_res.size(), base1.size() + el_base2.size())
        elif isinstance(base1, tuple) and isinstance(base2, tuple):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(res), len(base1))
            for el_res, el_base1 in zip(res, base1):
                self.assertTrue(isinstance(el_res, tuple))
                self.assertEqual(len(el_res), len(base2))
                for el_el_res, el_base2 in zip(el_res, base2):
                    self.assertTrue(isinstance(el_el_res, torch.Tensor))
                    self.assertTrue(isinstance(el_base2, torch.Tensor))
                    self.assertEqual(
                        el_el_res.size(), el_base1.size() + el_base2.size()
                    )
        else:
            # Wrong bases
            raise RuntimeError(
                "The bases given to `_assert_interleaved_struct` don't have"
                " the right structure."
            )

    @base_and_logging_tensor
    def test_vjp_err_check(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3)

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        inp = ctors.rand(4)
        v = ctors.ones(3)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to vjp must be either a Tensor"
        ):
            res = autogradF.vjp(foo, (inp, 2), v)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to vjp must"
        ):
            res = autogradF.vjp(bar, inp, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "The vector v can only be None if the user-provided function returns",
        ):
            res = autogradF.vjp(foo, inp)

        with self.assertRaisesRegex(
            RuntimeError, "The given v should contain a single Tensor."
        ):
            res = autogradF.vjp(foo, inp, (torch.ones_like(inp), torch.ones_like(inp)))

        with self.assertRaisesRegex(
            RuntimeError, "v has invalid size: should be torch.Size"
        ):
            res = autogradF.vjp(foo, inp, v[:2])

        res = autogradF.vjp(foo, inp, v)[1]
        self._assert_same_struct(res, inp)

    @base_and_logging_tensor
    def test_vjp_err_check_strict(self, ctors):
        def foo(a):
            return a.detach()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone()

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.vjp(foo, inp, v, strict=True)
        res = autogradF.vjp(foo, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "The output of the user-provided function is independent of input 0",
        ):
            res = autogradF.vjp(bar, inp, v, strict=True)
        res = autogradF.vjp(bar, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        # The Jacobian does not depend on the input
        def foo(a):
            return a.clone()

        inp.requires_grad_()
        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function is independent of input 0.",
        ):
            res = autogradF.vjp(foo, inp, v, create_graph=True, strict=True)
        res = autogradF.vjp(foo, inp, v, create_graph=True, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1], v)

    @base_and_logging_tensor
    def test_vjp_no_grad(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4)
        with torch.no_grad():
            res = autogradF.vjp(reducer, inputs, v)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

        inputs.requires_grad_()
        v.requires_grad_()
        with torch.no_grad():
            res = autogradF.vjp(reducer, inputs, v, create_graph=True)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

    @base_and_logging_tensor
    def test_vjp_output(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4)
        res = autogradF.vjp(reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = ctors.ones(2)
        out, vjp_val = autogradF.vjp(adder, inputs, v)
        self._assert_same_struct(vjp_val, inputs)
        self.assertIsNone(out.grad_fn)
        self.assertIsNone(vjp_val[0].grad_fn)
        self.assertIsNone(vjp_val[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y, x + y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0]))
        out, vjp_val = autogradF.vjp(adder, inputs, v)
        self._assert_same_struct(vjp_val, inputs)
        self.assertIsNone(out[0].grad_fn)
        self.assertIsNone(out[1].grad_fn)
        self.assertIsNone(vjp_val[0].grad_fn)
        self.assertIsNone(vjp_val[1].grad_fn)

    @base_and_logging_tensor
    def test_vjp_scalar(self, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones([])
        res = autogradF.vjp(reducer, inputs, v)
        self._assert_same_struct(res[0], v)
        self._assert_same_struct(res[1], inputs)

        res = autogradF.vjp(reducer, inputs)
        self._assert_same_struct(res[0], v)
        self._assert_same_struct(res[1], inputs)

        def expander(x):
            return x.unsqueeze(0).repeat(4)

        inputs = ctors.rand([])
        v = ctors.ones(4)
        res = autogradF.vjp(expander, inputs, v)
        self._assert_same_struct(res[0], v)
        self._assert_same_struct(res[1], inputs)

    @base_and_logging_tensor
    def test_vjp_create_graph(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(2, 2, dtype=torch.double)
        v = ctors.ones(2, dtype=torch.double)

        inputs.requires_grad_()
        v.requires_grad_()
        res = autogradF.vjp(reducer, inputs, v, create_graph=True)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        gradcheck(
            lambda inp, v: autogradF.vjp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )
        gradgradcheck(
            lambda inp, v: autogradF.vjp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )

        def adder(x, y):
            return 2 * x + 3 * y, x * y

        inputs = (
            ctors.rand(2, dtype=torch.double, requires_grad=True),
            ctors.rand(2, dtype=torch.double, requires_grad=True),
        )
        v = (
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
        )

        gradcheck(
            lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )
        gradgradcheck(
            lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )

        def foo(*args):
            x, y = args[:2]
            v = args[2:]

            x = x.cos()
            val, grad = autogradF.vjp(adder, (x, y), v, create_graph=True)

            return (
                val[0].exp()
                + val[1].exp()
                + grad[0].exp()
                + grad[1].exp()
                + x.exp()
                + y.exp()
            )

        gradcheck(foo, inputs + v)
        gradgradcheck(foo, inputs + v)

    @base_and_logging_tensor
    def test_jvp_err_check(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3)

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to jvp must be either a Tensor"
        ):
            res = autogradF.jvp(foo, (inp, 2), v)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to jvp must"
        ):
            res = autogradF.jvp(bar, inp, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "The vector v can only be None if the input to the user-provided function",
        ):
            res = autogradF.jvp(foo, inp)

        with self.assertRaisesRegex(
            RuntimeError, "The given v should contain a single Tensor."
        ):
            res = autogradF.jvp(foo, inp, (v, v))

        with self.assertRaisesRegex(
            RuntimeError, "v has invalid size: should be torch.Size"
        ):
            res = autogradF.jvp(foo, inp, v[:2])

        res = autogradF.jvp(foo, inp, v)[1]
        self._assert_same_struct(res, foo(inp))

    @base_and_logging_tensor
    def test_jvp_err_check_strict(self, ctors):
        def foo(a):
            return a.detach()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone()

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.jvp(foo, inp, v, strict=True)
        res = autogradF.jvp(foo, inp, v, strict=False)
        self._assert_same_struct(res[1], res[0])
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "The output of the user-provided function is independent of input 0",
        ):
            res = autogradF.jvp(bar, inp, v, strict=True)
        res = autogradF.jvp(bar, inp, v, strict=False)
        self._assert_same_struct(res[1], res[0])
        self.assertEqual(res[1].abs().sum(), 0.0)

        # The Jacobian does not depend on the input
        def foo(a):
            return a.clone()

        inp.requires_grad_()
        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function is independent of input 0.",
        ):
            res = autogradF.jvp(foo, inp, v, create_graph=True, strict=True)
        res = autogradF.jvp(foo, inp, v, create_graph=True, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1], v)

    @base_and_logging_tensor
    def test_jvp_no_grad(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        with torch.no_grad():
            res = autogradF.jvp(reducer, inputs, v)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

        inputs.requires_grad_()
        v.requires_grad_()
        with torch.no_grad():
            res = autogradF.jvp(reducer, inputs, v, create_graph=True)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

    @base_and_logging_tensor
    def test_jvp_output(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.jvp(reducer, inputs, v)
        self._assert_same_struct(res[1], res[0])
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = (ctors.ones(2), ctors.ones(2))
        out, jvp_val = autogradF.jvp(adder, inputs, v)
        self._assert_same_struct(jvp_val, out)
        self.assertIsNone(out.grad_fn)
        self.assertIsNone(jvp_val[0].grad_fn)
        self.assertIsNone(jvp_val[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y, x + y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0]))
        out, jvp_val = autogradF.jvp(adder, inputs, v)
        self._assert_same_struct(jvp_val, out)
        self.assertIsNone(out[0].grad_fn)
        self.assertIsNone(out[1].grad_fn)
        self.assertIsNone(jvp_val[0].grad_fn)
        self.assertIsNone(jvp_val[1].grad_fn)

    @base_and_logging_tensor
    def test_jvp_scalar(self, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.jvp(reducer, inputs, v)
        self._assert_same_struct(res[0], ctors.zeros([]))
        self._assert_same_struct(res[1], res[0])

        def expander(x):
            return x.unsqueeze(0).repeat(4)

        inputs = ctors.rand([])
        v = ctors.ones([])
        res = autogradF.jvp(expander, inputs, v)
        self._assert_same_struct(res[0], ctors.zeros(4))
        self._assert_same_struct(res[1], res[0])

        res = autogradF.jvp(expander, inputs)
        self._assert_same_struct(res[0], ctors.zeros(4))
        self._assert_same_struct(res[1], res[0])

    @base_and_logging_tensor
    def test_jvp_create_graph(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(2, 2, dtype=torch.double)
        v = ctors.ones(2, 2, dtype=torch.double)

        inputs.requires_grad_()
        v.requires_grad_()
        res = autogradF.jvp(reducer, inputs, v, create_graph=True)
        self._assert_same_struct(res[1], res[0])
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )
        gradgradcheck(
            lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )

        def adder(x, y):
            return 2 * x + 3 * y, x * y

        inputs = (
            ctors.rand(2, dtype=torch.double, requires_grad=True),
            ctors.rand(2, dtype=torch.double, requires_grad=True),
        )
        v = (
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
        )

        gradcheck(
            lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )
        gradgradcheck(
            lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )

        def foo(*args):
            x, y = args[:2]
            v = args[2:]

            x = x.cos()
            val, grad = autogradF.jvp(adder, (x, y), v, create_graph=True)

            return (
                val[0].exp()
                + val[1].exp()
                + grad[0].exp()
                + grad[1].exp()
                + x.exp()
                + y.exp()
            )

        gradcheck(foo, inputs + v)
        gradgradcheck(foo, inputs + v)

    def _test_construct_standard_basis_for(self, inputs):
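        # The private helper returns one basis chunk per input (matching that
        # input's dtype and device); concatenating the chunks along dim=1
        # should reproduce the identity matrix over the total number of input
        # elements, which is what this check verifies below.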
        numels = tuple(tensor.numel() for tensor in inputs)
        results = autogradF._construct_standard_basis_for(inputs, numels)
        for result, inp in zip(results, inputs):
            self.assertEqual(result.dtype, inp.dtype)
            self.assertEqual(result.device, inp.device)
        results = torch.cat(
            [result.to(device="cpu", dtype=torch.float) for result in results], dim=1
        )
        expected = torch.eye(results[0].shape[0], dtype=torch.float)
        self.assertEqual(results, expected)

    @base_and_logging_tensor
    def test_construct_standard_basis_for(self, ctors):
        test_cases = [
            (ctors.randn(2, 3),),
            (ctors.randn(1),),
            (ctors.randn([]),),
            (ctors.randn(1), ctors.randn([]), ctors.randn([])),
            (ctors.randn(2), ctors.randn(3), ctors.randn([])),
            (ctors.randn(2), ctors.randn([]), ctors.randn(3)),
            (ctors.randn(2, 3), ctors.randn(3), ctors.randn(3, 4, 2)),
            (ctors.randn(2, dtype=torch.float64), ctors.randn(3, dtype=torch.float32)),
        ]

        for inputs in test_cases:
            self._test_construct_standard_basis_for(inputs)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    @base_and_logging_tensor
    def test_construct_standard_basis_for_cuda(self, ctors):
        test_cases = [
            (ctors.randn(2), ctors.randn(3, device="cuda")),
            (ctors.randn(3, device="cuda"), ctors.randn(2)),
        ]

        for inputs in test_cases:
            self._test_construct_standard_basis_for(inputs)

    def _test_vectorize_raises_no_warnings(self, api, ctors):
        # vmap is an experimental prototype. When someone calls torch.vmap,
        # it raises a python warning. This test checks that
        # autogradF.{jacobian, hessian} don't raise that experimental prototype
        # warning; it is not nice for a public-facing API to raise a warning
        # no matter how it is called.
        def foo(a):
            return (a**2).sum()

        x = ctors.randn(3)
        with warnings.catch_warnings(record=True) as wa:
            result = api(foo, x, vectorize=True)
        self.assertEqual(len(wa), 0)

    @base_and_logging_tensor
    def test_jacobian_vectorize_raises_no_warnings(self, ctors):
        return self._test_vectorize_raises_no_warnings(autogradF.jacobian, ctors)

    @base_and_logging_tensor
    def test_hessian_vectorize_raises_no_warnings(self, ctors):
        return self._test_vectorize_raises_no_warnings(autogradF.hessian, ctors)

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_jacobian_err_check(self, vectorize, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3)

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        inp = ctors.rand(4)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to jacobian must be either a Tensor"
        ):
            res = autogradF.jacobian(foo, (inp, 2), vectorize=vectorize)

        with self.assertRaisesRegex(
            TypeError,
            "The outputs of the user-provided function given to jacobian must",
        ):
            res = autogradF.jacobian(bar, inp, vectorize=vectorize)

        res = autogradF.jacobian(foo, inp, vectorize=vectorize)
        self._assert_interleaved_struct(res, foo(inp), inp)

        def foo(a, b):
            return b, 3 * a.narrow(0, 0, 3)

        inp = (ctors.rand(4), ctors.rand(5))

        res = autogradF.jacobian(foo, inp, vectorize=vectorize)
        self._assert_interleaved_struct(res, foo(*inp), inp)

    @base_and_logging_tensor
    def test_jacobian_err_check_strict(self, ctors):
        def foo(a):
            return a.detach()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone()

        inp = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.jacobian(foo, inp, strict=True)
        res = autogradF.jacobian(foo, inp, strict=False)
        self._assert_interleaved_struct(res, foo(inp), inp)
        self.assertEqual(res.abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function is independent of input 0.",
        ):
            res = autogradF.jacobian(bar, inp, strict=True)
        res = autogradF.jacobian(bar, inp, strict=False)
        self._assert_interleaved_struct(res, foo(inp), inp)
        self.assertEqual(res.abs().sum(), 0.0)

        # The Jacobian does not depend on the input
        def foo(a):
            return a.clone()

        inp.requires_grad_()
        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function is independent of input 0.",
        ):
            res = autogradF.jacobian(foo, inp, create_graph=True, strict=True)
        res = autogradF.jacobian(foo, inp, create_graph=True, strict=False)
        self._assert_interleaved_struct(res, inp, inp)
        self.assertEqual(res, torch.eye(4))

    @base_and_logging_tensor
    def test_jacobian_err_check_strict_vectorize(self, ctors):
        def foo(x):
            return x

        inp = ctors.rand(4)
        with self.assertRaisesRegex(RuntimeError, "not supported together"):
            res = autogradF.jacobian(foo, inp, strict=True, vectorize=True)

    @base_and_logging_tensor
    def test_jacobian_no_grad(self, ctors):
        def exp_reducer(x):
            return x.exp().sum(dim=1)

        inputs = ctors.rand(4, 4)
        with torch.no_grad():
            res = autogradF.jacobian(exp_reducer, inputs)
        self.assertIsNone(res.grad_fn)
        self.assertNotEqual(res, ctors.zeros(4, 4))

        with torch.no_grad():
            res = autogradF.jacobian(exp_reducer, inputs, create_graph=True)
        self.assertIsNotNone(res.grad_fn)
        self.assertNotEqual(res, ctors.zeros(4, 4))

    @vectorized_logging_tensor
    def test_jacobian_output(self, vectorize, ctors):
        def exp_reducer(x):
            return x.exp().sum(dim=1)

        inputs = ctors.rand(4, 4)
        res = autogradF.jacobian(exp_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
        self.assertIsNone(res.grad_fn)

        def identity(x):
            return x.clone()

        inputs = ctors.rand(4)
        res = autogradF.jacobian(identity, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, identity(inputs), inputs)
        self.assertIsNone(res.grad_fn)
        self.assertEqual(res, torch.eye(4))

        def add_exp_reducer(x, y):
            return (x + y.exp()).sum(dim=1)

        inputs = (ctors.rand(4, 4), ctors.rand(4, 4))
        res = autogradF.jacobian(add_exp_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

    @vectorized_logging_tensor
    def test_jacobian_scalar(self, vectorize, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        res = autogradF.jacobian(reducer, inputs, vectorize=vectorize)
        self._assert_same_struct(res, inputs)

        def expander(x):
            return x.unsqueeze(0).repeat(4)

        inputs = ctors.rand([])
        res = autogradF.jacobian(expander, inputs, vectorize=vectorize)
        self._assert_same_struct(res, ctors.zeros(4))

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_jacobian_create_graph(self, vectorize, ctors):
        def exp_reducer(x):
            return x.exp().sum(dim=1)

        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
        res = autogradF.jacobian(
            exp_reducer, inputs, create_graph=True, vectorize=vectorize
        )
        self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
        self.assertIsNotNone(res.grad_fn)

        gradcheck(
            lambda inp: autogradF.jacobian(
                exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )
        gradgradcheck(
            lambda inp: autogradF.jacobian(
                exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )

        def add_exp_reducer(x, y):
            return (x + y).exp().sum(dim=1)

        inputs = (
            ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
            ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
        )
        res = autogradF.jacobian(
            add_exp_reducer, inputs, create_graph=True, vectorize=vectorize
        )
        self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda *inp: autogradF.jacobian(
                add_exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )
        gradgradcheck(
            lambda *inp: autogradF.jacobian(
                add_exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )

        def foo(x, y):
            x = x.cos()
            val, jac = autogradF.jacobian(
                add_exp_reducer, (x, y), create_graph=True, vectorize=vectorize
            )

            res = val[0].exp().sum() + val[1].exp().sum() + jac[0].exp().sum()
            res = res + jac[1].exp().sum() + x.exp().sum() + y.exp().sum()
            return res

        gradcheck(foo, inputs)
        gradgradcheck(foo, inputs)

    def _check_jacobian_vectorize_correctness(self, f, inputs, test_forward_ad=True):
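        # vectorize=True computes the jacobian with vmap instead of a Python
        # loop; compare it against the non-vectorized reference and, when
        # requested, against the vectorized forward-mode strategy as well.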
        expected = autogradF.jacobian(f, inputs, vectorize=False)
        result_backward_mode = autogradF.jacobian(f, inputs, vectorize=True)
        self.assertEqual(result_backward_mode, expected)

        if test_forward_ad:
            result_forward_mode = autogradF.jacobian(
                f, inputs, strategy="forward-mode", vectorize=True
            )
            self.assertEqual(result_forward_mode, expected)

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_simple(self, ctors):
        def f(x):
            return 3 * x**2

        x = ctors.randn(2, 3, 5)
        self._check_jacobian_vectorize_correctness(f, x)

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_multi_input(self, ctors):
        def f(x, y):
            return (x.cos() * x) @ y.sin()

        x = ctors.randn(2, 3)
        y = ctors.randn(3, 5)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_multi_input_multi_output(self, ctors):
        def f(x, y):
            return (x * x) @ y, x @ (x.sum(1) * y), y.sum()

        x = ctors.randn(5, 3)
        y = ctors.randn(3, 5)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_unrelated_outputs(self, ctors):
        def f(x, y):
            return x, y, x, y

        x = ctors.randn(2)
        y = ctors.randn(3)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_zero_dim(self, ctors):
        # zero-dim output
        def f(x, y):
            return x.sum(), y.sum(), x * y

        x = ctors.randn(3)
        y = ctors.randn(3)
        self._check_jacobian_vectorize_correctness(f, (x, y))

        # zero-dim input
        def g(x):
            return torch.stack([x, x, x])

        x = ctors.randn([])
        self._check_jacobian_vectorize_correctness(g, x)

        # Mixed zero-dim input / zero-dim output
        def h(x, y):
            return y.sum(), x * y

        x = ctors.randn([])
        y = ctors.randn(1)
        self._check_jacobian_vectorize_correctness(h, (x, y))

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_different_devices(self, ctors):
        def f(x, y):
            return x * y, (x * y).cuda()

        x = ctors.randn(3)
        y = ctors.randn(3)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_different_dtype(self, ctors):
        def f(x, y):
            return (x * y).float(), (x * y).double()

        x = ctors.randn(3)
        y = ctors.randn(3)
        # The Jacobian computed using forward AD has the dtype of the output
        # but the Jacobian computed with reverse AD has dtype of input
        self._check_jacobian_vectorize_correctness(f, (x, y), test_forward_ad=False)

    def _check_hessian_vectorize_correctness(self, f, inputs):
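        # Same idea as the jacobian check above: both the vmap-based hessian
        # and the variant using outer_jacobian_strategy="forward-mode" should
        # match the non-vectorized reference.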
986        expected = autogradF.hessian(f, inputs, vectorize=False)
987        result = autogradF.hessian(f, inputs, vectorize=True)
988        self.assertEqual(result, expected)
989
990        result_forward_mode = autogradF.hessian(
991            f, inputs, outer_jacobian_strategy="forward-mode", vectorize=True
992        )
993        self.assertEqual(result_forward_mode, expected)
994
995    @base_and_logging_tensor
996    def test_hessian_vectorize_correctness_simple(self, ctors):
997        def f(x):
998            return (3 * x**2).sum()
999
1000        x = ctors.randn(2, 3, 5)
1001        self._check_hessian_vectorize_correctness(f, x)
1002
1003    @base_and_logging_tensor
1004    def test_hessian_vectorize_correctness_multi_input(self, ctors):
1005        def f(x, y, z):
1006            return ((x.relu() * x) @ y.sin() @ z).sum()
1007
1008        x = ctors.randn(2, 3)
1009        y = ctors.randn(3, 5)
1010        z = ctors.randn(5, 5)
1011        self._check_hessian_vectorize_correctness(f, (x, y, z))
1012
1013    @base_and_logging_tensor
1014    def test_hessian_vectorize_correctness_unrelated_outputs(self, ctors):
1015        # output unrelated to one input
1016        def f(x, y):
1017            return (x**2).sum()
1018
1019        x = ctors.randn(2)
1020        y = ctors.randn(3)
1021        self._check_hessian_vectorize_correctness(f, (x, y))
1022
1023        # output unrelated to all inputs
1024        def f(x, y):
1025            return ctors.ones([])
1026
1027        x = ctors.randn(2)
1028        y = ctors.randn(3)
1029        self._check_hessian_vectorize_correctness(f, (x, y))
1030
1031    @parametrize("vectorize", [True, False])
1032    @base_and_logging_tensor
1033    def test_hessian_err_check(self, vectorize, ctors):
1034        def foo(a):
1035            return 3 * a.narrow(0, 0, 3).exp().sum()
1036
1037        def bar(a):
1038            return 3 * a.narrow(0, 0, 3), "bar"
1039
1040        def bar2(a):
1041            return 3 * a.narrow(0, 0, 3)
1042
1043        def bar3(a):
1044            return 3 * a.narrow(0, 0, 3), 3 * a.narrow(0, 0, 3)
1045
1046        inp = ctors.rand(4)
1047        with self.assertRaisesRegex(
1048            TypeError, "The inputs given to hessian must be either a Tensor"
1049        ):
1050            res = autogradF.hessian(foo, (inp, 2), vectorize=vectorize)
1051
1052        with self.assertRaisesRegex(
1053            TypeError, "The outputs of the user-provided function given to hessian must"
1054        ):
1055            res = autogradF.hessian(bar, inp, vectorize=vectorize)
1056
1057        err_msg_out = "The Tensor returned by the function given to hessian should contain a single element"
1058        with self.assertRaisesRegex(RuntimeError, err_msg_out):
1059            res = autogradF.hessian(bar2, inp, vectorize=vectorize)
1060
1061        with self.assertRaisesRegex(
1062            RuntimeError, "The function given to hessian should return a single Tensor"
1063        ):
1064            res = autogradF.hessian(bar3, inp, vectorize=vectorize)
1065
1066        res = autogradF.hessian(foo, inp, vectorize=vectorize)
1067        self._assert_interleaved_struct(res, inp, inp)
1068
1069        def foo(a, b):
1070            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()
1071
1072        inp = (ctors.rand(4), ctors.rand(5))
1073
1074        res = autogradF.hessian(foo, inp, vectorize=vectorize)
1075        self._assert_interleaved_struct(res, inp, inp)
1076
1077    @base_and_logging_tensor
1078    def test_hessian_err_check_strict(self, ctors):
1079        def foo(a):
1080            return a.detach().sum()
1081
1082        def bar(a):
1083            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
1084            return a.long().float().requires_grad_().clone().sum()
1085
1086        def bar2(a):
1087            # A Linear function for which the jacobian is independent of the input
1088            return (3 * a).sum()
1089
1090        inp = ctors.rand(4)
1091        with self.assertRaisesRegex(
1092            RuntimeError,
1093            "Output 0 of the user-provided function does not require gradients.",
1094        ):
1095            res = autogradF.hessian(foo, inp, strict=True)
1096        res = autogradF.hessian(foo, inp, strict=False)
1097        self._assert_interleaved_struct(res, inp, inp)
1098        self.assertEqual(res.abs().sum(), 0.0)
1099
1100        with self.assertRaisesRegex(
1101            RuntimeError,
1102            "jacobian of the user-provided function with respect to input 0",
1103        ):
1104            res = autogradF.hessian(bar, inp, strict=True)
1105        res = autogradF.hessian(bar, inp, strict=False)
1106        self._assert_interleaved_struct(res, inp, inp)
1107        self.assertEqual(res.abs().sum(), 0.0)
1108
1109        with self.assertRaisesRegex(
1110            RuntimeError,
1111            "jacobian of the user-provided function with respect to input 0 is",
1112        ):
1113            res = autogradF.hessian(bar2, inp, strict=True)
1114        res = autogradF.hessian(bar2, inp, strict=False)
1115        self._assert_interleaved_struct(res, inp, inp)
1116        self.assertEqual(res.abs().sum(), 0.0)
1117
1118    @base_and_logging_tensor
1119    def test_hessian_err_check_strict_vectorize(self, ctors):
1120        def foo(x):
1121            return (x**3).sum()
1122
1123        inp = ctors.rand(4)
1124        with self.assertRaisesRegex(RuntimeError, "not supported together"):
1125            res = autogradF.hessian(foo, inp, strict=True, vectorize=True)
1126
1127    @base_and_logging_tensor
1128    def test_hessian_no_grad(self, ctors):
1129        def pow_reducer(x):
1130            return x.pow(3).sum()
1131
1132        inputs = ctors.rand(2, 2)
1133        with torch.no_grad():
1134            res = autogradF.hessian(pow_reducer, inputs)
1135        self.assertIsNone(res[0][0].grad_fn)
1136        self.assertIsNone(res[0][1].grad_fn)
1137        self.assertIsNone(res[1][0].grad_fn)
1138        self.assertIsNone(res[1][1].grad_fn)
1139        self.assertNotEqual(res, ctors.zeros(2, 2, 2))
1140
1141        with torch.no_grad():
1142            res = autogradF.hessian(pow_reducer, inputs, create_graph=True)
1143        self.assertIsNotNone(res[0][0].grad_fn)
1144        self.assertIsNotNone(res[0][1].grad_fn)
1145        self.assertIsNotNone(res[1][0].grad_fn)
1146        self.assertIsNotNone(res[1][1].grad_fn)
1147        self.assertNotEqual(res, ctors.zeros(2, 2, 2))
1148
1149    @vectorized_logging_tensor
1150    def test_hessian_output(self, vectorize, ctors):
1151        def pow_reducer(x):
1152            return x.pow(3).sum()
1153
1154        inputs = ctors.rand(2, 2)
1155        res = autogradF.hessian(pow_reducer, inputs, vectorize=vectorize)
1156        self._assert_interleaved_struct(res, inputs, inputs)
1157        self.assertIsNone(res.grad_fn)
1158
1159        def add_pow_reducer(x, y):
1160            return (x + y).pow(3).sum()
1161
1162        inputs = (ctors.rand(2, 2), ctors.rand(2, 2))
1163        res = autogradF.hessian(add_pow_reducer, inputs, vectorize=vectorize)
1164        self._assert_interleaved_struct(res, inputs, inputs)
1165        self.assertIsNone(res[0][0].grad_fn)
1166        self.assertIsNone(res[0][1].grad_fn)
1167        self.assertIsNone(res[1][0].grad_fn)
1168        self.assertIsNone(res[1][1].grad_fn)
1169
1170    @parametrize("vectorize", [True, False])
1171    @base_and_logging_tensor
1172    def test_hessian_scalar(self, vectorize, ctors):
1173        def reducer(x):
1174            return x.sum()
1175
1176        inputs = ctors.rand(4, 4)
1177        res = autogradF.hessian(reducer, inputs, vectorize=vectorize)
1178        self._assert_interleaved_struct(res, inputs, inputs)
1179
1180        inputs = ctors.rand([])
1181        res = autogradF.hessian(reducer, inputs, vectorize=vectorize)
1182        self._assert_same_struct(res, inputs)
1183
1184        def bad_reducer(x):
1185            return x.sum().view(1, 1, 1)
1186
1187        inputs = ctors.rand(4, 4)
1188        res = autogradF.hessian(bad_reducer, inputs, vectorize=vectorize)
1189        self._assert_interleaved_struct(res, inputs, inputs)
1190
1191    @parametrize("vectorize", [True, False])
1192    @base_and_logging_tensor
1193    def test_hessian_create_graph(self, vectorize, ctors):
1194        def pow_reducer(x):
1195            return x.pow(3).sum()
1196
1197        inputs = ctors.rand(2, 2, dtype=torch.double, requires_grad=True)
1198        res = autogradF.hessian(
1199            pow_reducer, inputs, create_graph=True, vectorize=vectorize
1200        )
1201        self._assert_interleaved_struct(res, inputs, inputs)
1202        self.assertIsNotNone(res.grad_fn)
1203
1204        gradcheck(
1205            lambda inp: autogradF.hessian(
1206                pow_reducer, inp, create_graph=True, vectorize=vectorize
1207            ),
1208            inputs,
1209        )
1210        gradgradcheck(
1211            lambda inp: autogradF.hessian(
1212                pow_reducer, inp, create_graph=True, vectorize=vectorize
1213            ),
1214            inputs,
1215        )
1216
1217        def add_pow_reducer(x, y):
1218            return (x + y).pow(3).sum()
1219
1220        inputs = (
1221            ctors.rand(2, 2, dtype=torch.double, requires_grad=True),
1222            ctors.rand(2, 2, dtype=torch.double, requires_grad=True),
1223        )
1224        res = autogradF.hessian(
1225            add_pow_reducer, inputs, create_graph=True, vectorize=vectorize
1226        )
1227        self._assert_interleaved_struct(res, inputs, inputs)
1228        self.assertIsNotNone(res[0][0].grad_fn)
1229        self.assertIsNotNone(res[0][1].grad_fn)
1230        self.assertIsNotNone(res[1][0].grad_fn)
1231        self.assertIsNotNone(res[1][1].grad_fn)
1232
1233        def flatten(inp):
1234            return tuple(el_lvl2 for el_lvl1 in inp for el_lvl2 in el_lvl1)
1235
1236        gradcheck(
1237            lambda *inp: flatten(
1238                autogradF.hessian(
1239                    add_pow_reducer, inp, create_graph=True, vectorize=vectorize
1240                )
1241            ),
1242            inputs,
1243        )
1244        gradgradcheck(
1245            lambda *inp: flatten(
1246                autogradF.hessian(
1247                    add_pow_reducer, inp, create_graph=True, vectorize=vectorize
1248                )
1249            ),
1250            inputs,
1251        )
1252
1253        def foo(x, y):
1254            x = x.cos()
1255            val, hess = autogradF.hessian(
1256                add_pow_reducer, (x, y), create_graph=True, vectorize=vectorize
1257            )
1258
1259            res = val[0].cos().sum() + val[1].cos().sum() + hess[0].cos().sum()
1260            res = res + hess[1].cos().sum() + x.cos().sum() + y.cos().sum()
1261            return res
1262
1263        gradcheck(foo, inputs)
1264        gradgradcheck(foo, inputs)
1265
1266    @base_and_logging_tensor
1267    def test_vhp_err_check(self, ctors):
1268        def foo(a):
1269            return 3 * a.narrow(0, 0, 3).exp().sum()
1270
1271        def bar(a):
1272            return 3 * a.narrow(0, 0, 3), "bar"
1273
1274        def bar2(a):
1275            return 3 * a.narrow(0, 0, 3)
1276
1277        inp = ctors.rand(4)
1278        v = ctors.rand(4)
1279        with self.assertRaisesRegex(
1280            TypeError, "The inputs given to vhp must be either a Tensor"
1281        ):
1282            res = autogradF.vhp(foo, (inp, 2), v)
1283
1284        with self.assertRaisesRegex(
1285            TypeError, "The outputs of the user-provided function given to vhp must"
1286        ):
1287            res = autogradF.vhp(bar, inp, v)
1288
1289        err_msg_out = "The Tensor returned by the function given to vhp should contain a single element"
1290        with self.assertRaisesRegex(RuntimeError, err_msg_out):
1291            res = autogradF.vhp(bar2, inp, v)
1292
1293        with self.assertRaisesRegex(RuntimeError, "v has invalid size:"):
1294            res = autogradF.vhp(foo, inp, ctors.rand(5))
1295
1296        with self.assertRaisesRegex(
1297            TypeError,
1298            "The v given to vhp must be either a Tensor or a tuple of Tensors",
1299        ):
1300            res = autogradF.vhp(foo, inp, (v, 2))
1301
1302        res = autogradF.vhp(foo, inp, v)
1303        self._assert_same_struct(res[1], inp)
1304
1305        def foo(a, b):
1306            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()
1307
1308        inp = (ctors.rand(4), ctors.rand(5))
1309        v = (ctors.rand(4), ctors.rand(5))
1310
1311        res = autogradF.vhp(foo, inp, v)
1312        self._assert_same_struct(res[1], inp)
1313
1314    @base_and_logging_tensor
1315    def test_vhp_err_check_strict(self, ctors):
1316        def foo(a):
1317            return a.detach().sum()
1318
1319        def bar(a):
1320            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
1321            return a.long().float().requires_grad_().clone().sum()
1322
1323        def bar2(a):
1324            # A Linear function for which the jacobian is independent of the input
1325            return (3 * a).sum()
1326
1327        inp = ctors.rand(4)
1328        v = ctors.rand(4)
1329        with self.assertRaisesRegex(
1330            RuntimeError,
1331            "Output 0 of the user-provided function does not require gradients.",
1332        ):
1333            res = autogradF.vhp(foo, inp, v, strict=True)
1334        res = autogradF.vhp(foo, inp, v, strict=False)
1335        self._assert_same_struct(res[1], inp)
1336        self.assertEqual(res[1].abs().sum(), 0.0)
1337
1338        with self.assertRaisesRegex(
1339            RuntimeError,
1340            "The output of the user-provided function is independent of input 0",
1341        ):
1342            res = autogradF.vhp(bar, inp, v, strict=True)
1343        res = autogradF.vhp(bar, inp, v, strict=False)
1344        self._assert_same_struct(res[1], inp)
1345        self.assertEqual(res[1].abs().sum(), 0.0)
1346
1347        with self.assertRaisesRegex(
1348            RuntimeError,
1349            "jacobian of the user-provided function with respect to input 0 is",
1350        ):
1351            res = autogradF.vhp(bar2, inp, v, strict=True)
1352        res = autogradF.vhp(bar2, inp, v, strict=False)
1353        self._assert_same_struct(res[1], inp)
1354        self.assertEqual(res[1].abs().sum(), 0.0)
1355
1356    @base_and_logging_tensor
1357    def test_vhp_no_grad(self, ctors):
1358        def reducer(x):
1359            return x.exp().sum()
1360
1361        inputs = ctors.rand(4, 4)
1362        v = ctors.ones(4, 4)
1363        with torch.no_grad():
1364            res = autogradF.vhp(reducer, inputs, v)
1365        self.assertIsNone(res[0].grad_fn)
1366        self.assertIsNone(res[1].grad_fn)
1367        self.assertNotEqual(res[1], ctors.zeros(4, 4))
1368
1369        with torch.no_grad():
1370            res = autogradF.vhp(reducer, inputs, v, create_graph=True)
1371        self.assertIsNotNone(res[0].grad_fn)
1372        self.assertIsNotNone(res[1].grad_fn)
1373        self.assertNotEqual(res[1], ctors.zeros(4, 4))
1374
1375    @base_and_logging_tensor
1376    def test_vhp_output(self, ctors):
1377        def foo(a):
1378            return 3 * a.narrow(0, 0, 3).exp().sum()
1379
1380        inputs = ctors.rand(4, 4)
1381        v = ctors.ones(4, 4)
1382        res = autogradF.vhp(foo, inputs, v)
1383        self._assert_same_struct(res[1], inputs)
1384        self.assertIsNone(res[0].grad_fn)
1385        self.assertIsNone(res[1].grad_fn)
1386
1387        def bar(a, b):
1388            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
1389
1390        inputs = (ctors.rand(3), ctors.rand(4))
1391        v = (ctors.ones(3), ctors.ones(4))
1392        out, vhp_val = autogradF.vhp(bar, inputs, v)
1393        self._assert_same_struct(vhp_val, inputs)
1394        self.assertIsNone(out.grad_fn)
1395        self.assertIsNone(vhp_val[0].grad_fn)
1396        self.assertIsNone(vhp_val[1].grad_fn)
1397
1398    @base_and_logging_tensor
1399    def test_vhp_scalar(self, ctors):
1400        def reducer(x):
1401            return x.sum()
1402
1403        inputs = ctors.rand(4, 4)
1404        v = ctors.ones(4, 4)
1405        res = autogradF.vhp(reducer, inputs, v)
1406        self._assert_same_struct(res[1], inputs)
1407
1408        inputs = ctors.rand([])
1409        v = ctors.rand([])
1410        res = autogradF.vhp(reducer, inputs, v)
1411        self._assert_same_struct(res[1], inputs)
1412
1413        res = autogradF.vhp(reducer, inputs)
1414        self._assert_same_struct(res[1], inputs)
1415
1416        def bad_reducer(x):
1417            return x.sum().view(1, 1, 1)
1418
1419        inputs = ctors.rand(4, 4)
1420        v = ctors.rand(4, 4)
1421        res = autogradF.vhp(bad_reducer, inputs, v)
1422        self._assert_same_struct(res[1], inputs)
1423
1424    @base_and_logging_tensor
1425    def test_vhp_create_graph(self, ctors):
1426        def foo(a):
1427            return 3 * a.narrow(0, 0, 3).exp().sum()
1428
1429        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
1430        v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True)
1431        res = autogradF.vhp(foo, inputs, v, create_graph=True)
1432        self._assert_same_struct(res[1], inputs)
1433        self.assertIsNotNone(res[0].grad_fn)
1434        self.assertIsNotNone(res[1].grad_fn)
1435
1436        gradcheck(
1437            lambda inp, v: autogradF.vhp(foo, inp, v, create_graph=True), (inputs, v)
1438        )
1439        gradgradcheck(
1440            lambda inp, v: autogradF.vhp(foo, inp, v, create_graph=True), (inputs, v)
1441        )
1442
1443        def bar(a, b):
1444            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
1445
1446        inputs = (
1447            ctors.rand(3, dtype=torch.double, requires_grad=True),
1448            ctors.rand(4, dtype=torch.double, requires_grad=True),
1449        )
1450        v = (
1451            ctors.ones(3, dtype=torch.double, requires_grad=True),
1452            ctors.ones(4, dtype=torch.double, requires_grad=True),
1453        )
1454        out, vhp_val = autogradF.vhp(bar, inputs, v, create_graph=True)
1455        self._assert_same_struct(vhp_val, inputs)
1456        self.assertIsNotNone(out.grad_fn)
1457        self.assertIsNotNone(vhp_val[0].grad_fn)
1458        self.assertIsNotNone(vhp_val[1].grad_fn)
1459
1460        gradcheck(
1461            lambda *args: autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1],
1462            inputs + v,
1463        )
1464        gradgradcheck(
1465            lambda *args: autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1],
1466            inputs + v,
1467        )
1468
1469        def foo(*args):
1470            x, y = args[:2]
1471            v = args[2:]
1472
1473            x = x.cos()
1474            val, grad = autogradF.vhp(bar, (x, y), v, create_graph=True)
1475
1476            return (
1477                val.cos()
1478                + grad[0].cos().sum()
1479                + grad[1].cos()
1480                + x.cos().sum()
1481                + y.cos()
1482            )
1483
1484        gradcheck(foo, inputs + v)
1485        gradgradcheck(foo, inputs + v)
1486
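    # hvp argument validation mirrors vhp: inputs and v must be Tensors or
    # tuples of Tensors, the user function must return a single-element Tensor,
    # and each v must match the size of the corresponding input.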
1487    @base_and_logging_tensor
1488    def test_hvp_err_check(self, ctors):
1489        def foo(a):
1490            return 3 * a.narrow(0, 0, 3).exp().sum()
1491
1492        def bar(a):
1493            return 3 * a.narrow(0, 0, 3), "bar"
1494
1495        def bar2(a):
1496            return 3 * a.narrow(0, 0, 3)
1497
1498        inp = ctors.rand(4)
1499        v = ctors.rand(4)
1500        res = autogradF.hvp(foo, inp, v)
1501        with self.assertRaisesRegex(
1502            TypeError, "The inputs given to hvp must be either a Tensor"
1503        ):
1504            res = autogradF.hvp(foo, (inp, 2), v)
1505
1506        with self.assertRaisesRegex(
1507            TypeError, "The outputs of the user-provided function given to hvp must"
1508        ):
1509            res = autogradF.hvp(bar, inp, v)
1510
1511        err_msg_out = "The Tensor returned by the function given to hvp should contain a single element"
1512        with self.assertRaisesRegex(RuntimeError, err_msg_out):
1513            res = autogradF.hvp(bar2, inp, v)
1514
1515        with self.assertRaisesRegex(RuntimeError, "v has invalid size:"):
1516            res = autogradF.hvp(foo, inp, ctors.rand(5))
1517
1518        with self.assertRaisesRegex(
1519            TypeError,
1520            "The v given to hvp must be either a Tensor or a tuple of Tensors",
1521        ):
1522            res = autogradF.hvp(foo, inp, (v, 2))
1523
1524        res = autogradF.hvp(foo, inp, v)
1525        self._assert_same_struct(res[1], inp)
1526
1527        def foo(a, b):
1528            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()
1529
1530        inp = (ctors.rand(4), ctors.rand(5))
1531        v = (ctors.rand(4), ctors.rand(5))
1532
1533        res = autogradF.hvp(foo, inp, v)
1534        self._assert_same_struct(res[1], inp)
1535
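    # strict=True turns silently-zero results into errors: a detached output,
    # an output disconnected from the input, and a linear function (zero
    # Hessian) all raise, while strict=False returns zero-filled products.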
1536    @base_and_logging_tensor
1537    def test_hvp_err_check_strict(self, ctors):
1538        def foo(a):
1539            return a.detach().sum()
1540
1541        def bar(a):
1542            # Make a non-leaf Tensor that requires grad but is not connected to the input
1543            return a.long().float().requires_grad_().clone().sum()
1544
1545        def bar2(a):
1546            # A linear function whose Jacobian is independent of the input (so its Hessian is zero)
1547            return (3 * a).sum()
1548
1549        inp = ctors.rand(4)
1550        v = ctors.rand(4)
1551        with self.assertRaisesRegex(
1552            RuntimeError,
1553            "Output 0 of the user-provided function does not require gradients.",
1554        ):
1555            res = autogradF.hvp(foo, inp, v, strict=True)
1556        res = autogradF.hvp(foo, inp, v, strict=False)
1557        self._assert_same_struct(res[1], inp)
1558        self.assertEqual(res[1].abs().sum(), 0.0)
1559
1560        with self.assertRaisesRegex(
1561            RuntimeError,
1562            "The output of the user-provided function is independent of input 0",
1563        ):
1564            res = autogradF.hvp(bar, inp, v, strict=True)
1565        res = autogradF.hvp(bar, inp, v, strict=False)
1566        self._assert_same_struct(res[1], inp)
1567        self.assertEqual(res[1].abs().sum(), 0.0)
1568
1569        with self.assertRaisesRegex(
1570            RuntimeError,
1571            "jacobian of the user-provided function with respect to input 0 is",
1572        ):
1573            res = autogradF.hvp(bar2, inp, v, strict=True)
1574        res = autogradF.hvp(bar2, inp, v, strict=False)
1575        self._assert_same_struct(res[1], inp)
1576        self.assertEqual(res[1].abs().sum(), 0.0)
1577
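    # hvp can be called under torch.no_grad() and still returns a proper
    # (non-zero) product; the results only carry a grad_fn when
    # create_graph=True is passed explicitly.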
1578    @base_and_logging_tensor
1579    def test_hvp_no_grad(self, ctors):
1580        def reducer(x):
1581            return x.exp().sum()
1582
1583        inputs = ctors.rand(4, 4)
1584        v = ctors.ones(4, 4)
1585        with torch.no_grad():
1586            res = autogradF.hvp(reducer, inputs, v)
1587        self.assertIsNone(res[0].grad_fn)
1588        self.assertIsNone(res[1].grad_fn)
1589        self.assertNotEqual(res[1], ctors.zeros(4, 4))
1590
1591        with torch.no_grad():
1592            res = autogradF.hvp(reducer, inputs, v, create_graph=True)
1593        self.assertIsNotNone(res[0].grad_fn)
1594        self.assertIsNotNone(res[1].grad_fn)
1595        self.assertNotEqual(res[1], ctors.zeros(4, 4))
1596
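    # Without create_graph, the (output, hvp) pair is detached (grad_fn is
    # None) and the hvp matches the structure of the inputs, for both
    # single-Tensor and tuple-of-Tensors inputs.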
1597    @base_and_logging_tensor
1598    def test_hvp_output(self, ctors):
1599        def foo(a):
1600            return 3 * a.narrow(0, 0, 3).exp().sum()
1601
1602        inputs = ctors.rand(4, 4)
1603        v = ctors.ones(4, 4)
1604        res = autogradF.hvp(foo, inputs, v)
1605        self._assert_same_struct(res[1], inputs)
1606        self.assertIsNone(res[0].grad_fn)
1607        self.assertIsNone(res[1].grad_fn)
1608
1609        def bar(a, b):
1610            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
1611
1612        inputs = (ctors.rand(3), ctors.rand(4))
1613        v = (ctors.ones(3), ctors.ones(4))
1614        out, hvp_val = autogradF.hvp(bar, inputs, v)
1615        self._assert_same_struct(hvp_val, inputs)
1616        self.assertIsNone(out.grad_fn)
1617        self.assertIsNone(hvp_val[0].grad_fn)
1618        self.assertIsNone(hvp_val[1].grad_fn)
1619
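    # Same scalar-input handling as test_vhp_scalar above, but for hvp: 0-dim
    # inputs work, v may be omitted for a single-element input, and a
    # single-element output of any shape is accepted.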
1620    @base_and_logging_tensor
1621    def test_hvp_scalar(self, ctors):
1622        def reducer(x):
1623            return x.exp().sum()
1624
1625        inputs = ctors.rand(4, 4)
1626        v = ctors.ones(4, 4)
1627        res = autogradF.hvp(reducer, inputs, v)
1628        self._assert_same_struct(res[1], inputs)
1629
1630        inputs = ctors.rand([])
1631        v = ctors.rand([])
1632        res = autogradF.hvp(reducer, inputs, v)
1633        self._assert_same_struct(res[1], inputs)
1634
1635        res = autogradF.hvp(reducer, inputs)
1636        self._assert_same_struct(res[1], inputs)
1637
1638        def bad_reducer(x):
1639            return x.exp().sum().view(1, 1, 1)
1640
1641        inputs = ctors.rand(4, 4)
1642        v = ctors.rand(4, 4)
1643        res = autogradF.hvp(bad_reducer, inputs, v)
1644        self._assert_same_struct(res[1], inputs)
1645
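    # create_graph=True keeps the hvp differentiable; gradcheck and
    # gradgradcheck below verify first- and second-order gradients through
    # nested hvp calls.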
1646    @base_and_logging_tensor
1647    def test_hvp_create_graph(self, ctors):
1648        def foo(a):
1649            return 3 * a.narrow(0, 0, 3).exp().sum()
1650
1651        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
1652        v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True)
1653        res = autogradF.hvp(foo, inputs, v, create_graph=True)
1654        self._assert_same_struct(res[1], inputs)
1655        self.assertIsNotNone(res[0].grad_fn)
1656        self.assertIsNotNone(res[1].grad_fn)
1657
1658        gradcheck(
1659            lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v)
1660        )
1661        gradgradcheck(
1662            lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v)
1663        )
1664
1665        def bar(a, b):
1666            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()
1667
1668        inputs = (
1669            ctors.rand(3, dtype=torch.double, requires_grad=True),
1670            ctors.rand(4, dtype=torch.double, requires_grad=True),
1671        )
1672        v = (
1673            ctors.ones(3, dtype=torch.double, requires_grad=True),
1674            ctors.ones(4, dtype=torch.double, requires_grad=True),
1675        )
1676        out, hvp_val = autogradF.hvp(bar, inputs, v, create_graph=True)
1677        self._assert_same_struct(hvp_val, inputs)
1678        self.assertIsNotNone(out.grad_fn)
1679        self.assertIsNotNone(hvp_val[0].grad_fn)
1680        self.assertIsNotNone(hvp_val[1].grad_fn)
1681
1682        gradcheck(
1683            lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1],
1684            inputs + v,
1685        )
1686        gradgradcheck(
1687            lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1],
1688            inputs + v,
1689        )
1690
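        # Compose hvp with other differentiable ops so that gradcheck exercises
        # the graph flowing through both the returned value and the hvp terms.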
1691        def foo(*args):
1692            x, y = args[:2]
1693            v = args[2:]
1694
1695            x = x.cos()
1696            val, grad = autogradF.hvp(bar, (x, y), v, create_graph=True)
1697
1698            return (
1699                val.cos()
1700                + grad[0].cos().sum()
1701                + grad[1].cos()
1702                + x.cos().sum()
1703                + y.cos()
1704            )
1705
1706        gradcheck(foo, inputs + v)
1707        gradgradcheck(foo, inputs + v)
1708
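    # Consistency check between the dense Jacobian and the matrix-free
    # products: with J = jacobian(foo, inputs), the identities below should
    # hold up to numerical precision:
    #     jvp(foo, inputs, v)[1] == J @ v
    #     vjp(foo, inputs, v)[1] == v @ J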
1709    @base_and_logging_tensor
1710    def test_jacobian_match_vjp_jvp(self, ctors):
1711        def foo(x):
1712            return x**3 + x.sum()
1713
1714        inputs = ctors.rand(4)
1715        v = ctors.rand(4)
1716
1717        jac = autogradF.jacobian(foo, inputs)
1718        jvp = autogradF.jvp(foo, inputs, v)[1]
1719        vjp = autogradF.vjp(foo, inputs, v)[1]
1720
1721        self.assertEqual(jvp, torch.mm(jac, v.unsqueeze(1)).squeeze(1))
1722        self.assertEqual(vjp, torch.mm(v.unsqueeze(0), jac).squeeze(0))
1723
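    # Analogous second-order check: with H = hessian(foo, inputs),
    #     hvp(foo, inputs, v)[1] == H @ v
    #     vhp(foo, inputs, v)[1] == v @ H
    # which coincide numerically here because the Hessian of this scalar
    # function is symmetric.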
1724    @base_and_logging_tensor
1725    def test_hessian_match_vhp_hvp(self, ctors):
1726        def foo(a):
1727            return 3 * a.narrow(0, 0, 3).exp().sum()
1728
1729        inputs = ctors.rand(4)
1730        v = ctors.rand(4)
1731
1732        hes = autogradF.hessian(foo, inputs)
1733        hvp = autogradF.hvp(foo, inputs, v)[1]
1734        vhp = autogradF.vhp(foo, inputs, v)[1]
1735
1736        self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
1737        self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))
1738
1739
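# Expand the parametrized tests defined above into concrete test methods
# (one per subtest name, e.g. *_base_tensor and *_logging_tensor variants).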
1740instantiate_parametrized_tests(TestAutogradFunctional)
1741
1742if __name__ == "__main__":
1743    run_tests()
1744