# Owner(s): ["module: autograd"]

import types
import unittest
import warnings

import torch
import torch.autograd.functional as autogradF
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    gradcheck,
    gradgradcheck,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TestCase,
)
from torch.testing._internal.logging_tensor import LoggingTensor


# Utilities for parametrizing the tensor constructors used in autograd tests
#
# TODO: maybe move somewhere so other tests can also use
#
# NB: Not all factory functions included. A complete(?) list can be found here:
# https://pytorch.org/cppdocs/notes/tensor_creation.html
base_ctors_dict = {
    "ones": torch.ones,
    "zeros": torch.zeros,
    "randn": torch.randn,
    "rand": torch.rand,
    "tensor": torch.tensor,
}
base_ctors = types.SimpleNamespace(**base_ctors_dict)


def wrap_with_logging_tensor(ctor):
    def wrapper(*args, **kwargs):
        requires_grad = kwargs.pop("requires_grad", False)
        return LoggingTensor(ctor(*args, **kwargs), requires_grad=requires_grad)

    return wrapper


logging_tensor_ctors_dict = {
    k: wrap_with_logging_tensor(ctor) for (k, ctor) in base_ctors_dict.items()
}
logging_tensor_ctors = types.SimpleNamespace(**logging_tensor_ctors_dict)

base_and_logging_tensor = parametrize(
    "ctors",
    [
        subtest(base_ctors, name="base_tensor"),
        subtest(logging_tensor_ctors, name="logging_tensor"),
    ],
)

FIXME_base_and_xfail_logging_tensor = parametrize(
    "ctors",
    [
        subtest(base_ctors, name="base_tensor"),
        subtest(
            logging_tensor_ctors,
            name="logging_tensor",
            decorators=[unittest.expectedFailure],
        ),
    ],
)

# NB: This is equivalent to having both @parametrize("vectorized", [True, False]) and
# FIXME_base_and_xfail_logging_tensor, except the non-vectorized logging_tensor case is
# actually expected to succeed
FIXME_xfail_vectorized_logging_tensor = parametrize(
    "vectorize,ctors",
    [
        subtest((True, base_ctors), name="vectorized_base_tensor"),
        subtest((False, base_ctors), name="base_tensor"),
        subtest(
            (True, logging_tensor_ctors),
            name="vectorized_logging_tensor",
            decorators=[unittest.expectedFailure],
        ),
        subtest((False, logging_tensor_ctors), name="logging_tensor"),
    ],
)

vectorized_logging_tensor = parametrize(
    "vectorize,ctors",
    [
        subtest((True, base_ctors), name="vectorized_base_tensor"),
        subtest((False, base_ctors), name="base_tensor"),
        subtest((True, logging_tensor_ctors), name="vectorized_logging_tensor"),
        subtest((False, logging_tensor_ctors), name="logging_tensor"),
    ],
)
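
# For illustration: a test decorated with one of the parametrizations above
# receives the selected namespace as `ctors`. Under the "base_tensor" subtest,
# `ctors.rand(4)` is just `torch.rand(4)`; under the "logging_tensor" subtest it
# behaves roughly like `LoggingTensor(torch.rand(4), requires_grad=False)`, per
# `wrap_with_logging_tensor` above.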


class TestAutogradFunctional(TestCase):
    def _assert_same_struct(self, res, base):
        # base and res should be Tensors or tuple of Tensors with the same size
        if isinstance(base, torch.Tensor):
            self.assertTrue(isinstance(res, torch.Tensor))
            self.assertEqual(base.size(), res.size())
        elif isinstance(base, tuple):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(base), len(res))
            for el_base, el_res in zip(base, res):
                self.assertTrue(isinstance(el_base, torch.Tensor))
                self.assertTrue(isinstance(el_res, torch.Tensor))
                self.assertEqual(el_base.size(), el_res.size())
        else:
            # Wrong base
            raise RuntimeError(
                "The base given to `_assert_same_struct` doesn't have"
                " the right structure."
            )

    def _assert_interleaved_struct(self, res, base1, base2):
        # base1 and base2 can be Tensors or tuples of Tensors.
        # If they are tuples, res should be a tuple as well.
        # The indexing works as follows for base1, base2 being
        # - tuple, tuple: res[i][j][k][l] = (base1[i][k], base2[j][l])
        # - tuple, Tensor: res[i][k][l] = (base1[i][k], base2[l])
        # - Tensor, tuple: res[i][j][l] = (base1[i], base2[j][l])
        # - Tensor, Tensor: res[k][l] = (base1[k], base2[l])
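        # For example, if base1 has size (2, 3) and base2 has size (5,), res is
        # expected to be a single Tensor of size (2, 3, 5); each tuple level in
        # base1/base2 adds a corresponding level of tuple nesting to res.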
        if isinstance(base1, torch.Tensor) and isinstance(base2, torch.Tensor):
            self.assertTrue(isinstance(res, torch.Tensor))
            self.assertEqual(res.size(), base1.size() + base2.size())
        elif isinstance(base1, tuple) and isinstance(base2, torch.Tensor):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(res), len(base1))
            for el_res, el_base1 in zip(res, base1):
                self.assertTrue(isinstance(el_res, torch.Tensor))
                self.assertTrue(isinstance(el_base1, torch.Tensor))
                self.assertEqual(el_res.size(), el_base1.size() + base2.size())
        elif isinstance(base1, torch.Tensor) and isinstance(base2, tuple):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(res), len(base2))
            for el_res, el_base2 in zip(res, base2):
                self.assertTrue(isinstance(el_res, torch.Tensor))
                self.assertTrue(isinstance(el_base2, torch.Tensor))
                self.assertEqual(el_res.size(), base1.size() + el_base2.size())
        elif isinstance(base1, tuple) and isinstance(base2, tuple):
            self.assertTrue(isinstance(res, tuple))
            self.assertEqual(len(res), len(base1))
            for el_res, el_base1 in zip(res, base1):
                self.assertTrue(isinstance(el_res, tuple))
                self.assertEqual(len(el_res), len(base2))
                for el_el_res, el_base2 in zip(el_res, base2):
                    self.assertTrue(isinstance(el_el_res, torch.Tensor))
                    self.assertTrue(isinstance(el_base2, torch.Tensor))
                    self.assertEqual(
                        el_el_res.size(), el_base1.size() + el_base2.size()
                    )
        else:
            # Wrong bases
            raise RuntimeError(
                "The bases given to `_assert_interleaved_struct` don't have"
                " the right structure."
            )

    @base_and_logging_tensor
    def test_vjp_err_check(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3)

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        inp = ctors.rand(4)
        v = ctors.ones(3)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to vjp must be either a Tensor"
        ):
            res = autogradF.vjp(foo, (inp, 2), v)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to vjp must"
        ):
            res = autogradF.vjp(bar, inp, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "The vector v can only be None if the user-provided function returns",
        ):
            res = autogradF.vjp(foo, inp)

        with self.assertRaisesRegex(
            RuntimeError, "The given v should contain a single Tensor."
        ):
            res = autogradF.vjp(foo, inp, (torch.ones_like(inp), torch.ones_like(inp)))

        with self.assertRaisesRegex(
            RuntimeError, "v has invalid size: should be torch.Size"
        ):
            res = autogradF.vjp(foo, inp, v[:2])

        res = autogradF.vjp(foo, inp, v)[1]
        self._assert_same_struct(res, inp)

    @base_and_logging_tensor
    def test_vjp_err_check_strict(self, ctors):
        def foo(a):
            return a.detach()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone()

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.vjp(foo, inp, v, strict=True)
        res = autogradF.vjp(foo, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "The output of the user-provided function is independent of input 0",
        ):
            res = autogradF.vjp(bar, inp, v, strict=True)
        res = autogradF.vjp(bar, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        # The Jacobian does not depend on the input
        def foo(a):
            return a.clone()

        inp.requires_grad_()
        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function is independent of input 0.",
        ):
            res = autogradF.vjp(foo, inp, v, create_graph=True, strict=True)
        res = autogradF.vjp(foo, inp, v, create_graph=True, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1], v)

    @base_and_logging_tensor
    def test_vjp_no_grad(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4)
        with torch.no_grad():
            res = autogradF.vjp(reducer, inputs, v)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

        inputs.requires_grad_()
        v.requires_grad_()
        with torch.no_grad():
            res = autogradF.vjp(reducer, inputs, v, create_graph=True)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

    @base_and_logging_tensor
    def test_vjp_output(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4)
        res = autogradF.vjp(reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = ctors.ones(2)
        out, vjp_val = autogradF.vjp(adder, inputs, v)
        self._assert_same_struct(vjp_val, inputs)
        self.assertIsNone(out.grad_fn)
        self.assertIsNone(vjp_val[0].grad_fn)
        self.assertIsNone(vjp_val[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y, x + y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0]))
        out, vjp_val = autogradF.vjp(adder, inputs, v)
        self._assert_same_struct(vjp_val, inputs)
        self.assertIsNone(out[0].grad_fn)
        self.assertIsNone(out[1].grad_fn)
        self.assertIsNone(vjp_val[0].grad_fn)
        self.assertIsNone(vjp_val[1].grad_fn)

    @base_and_logging_tensor
    def test_vjp_scalar(self, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones([])
        res = autogradF.vjp(reducer, inputs, v)
        self._assert_same_struct(res[0], v)
        self._assert_same_struct(res[1], inputs)

        res = autogradF.vjp(reducer, inputs)
        self._assert_same_struct(res[0], v)
        self._assert_same_struct(res[1], inputs)

        def expander(x):
            return x.unsqueeze(0).repeat(4)

        inputs = ctors.rand([])
        v = ctors.ones(4)
        res = autogradF.vjp(expander, inputs, v)
        self._assert_same_struct(res[0], v)
        self._assert_same_struct(res[1], inputs)

    @base_and_logging_tensor
    def test_vjp_create_graph(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(2, 2, dtype=torch.double)
        v = ctors.ones(2, dtype=torch.double)

        inputs.requires_grad_()
        v.requires_grad_()
        res = autogradF.vjp(reducer, inputs, v, create_graph=True)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda inp, v: autogradF.vjp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )
        gradgradcheck(
            lambda inp, v: autogradF.vjp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )

        def adder(x, y):
            return 2 * x + 3 * y, x * y

        inputs = (
            ctors.rand(2, dtype=torch.double, requires_grad=True),
            ctors.rand(2, dtype=torch.double, requires_grad=True),
        )
        v = (
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
        )

        gradcheck(
            lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )
        gradgradcheck(
            lambda *args: autogradF.vjp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )

        def foo(*args):
            x, y = args[:2]
            v = args[2:]

            x = x.cos()
            val, grad = autogradF.vjp(adder, (x, y), v, create_graph=True)

            return (
                val[0].exp()
                + val[1].exp()
                + grad[0].exp()
                + grad[1].exp()
                + x.exp()
                + y.exp()
            )

        gradcheck(foo, inputs + v)
        gradgradcheck(foo, inputs + v)

    @base_and_logging_tensor
    def test_jvp_err_check(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3)

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to jvp must be either a Tensor"
        ):
            res = autogradF.jvp(foo, (inp, 2), v)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to jvp must"
        ):
            res = autogradF.jvp(bar, inp, v)

        with self.assertRaisesRegex(
            RuntimeError,
            "The vector v can only be None if the input to the user-provided function",
        ):
            res = autogradF.jvp(foo, inp)

        with self.assertRaisesRegex(
            RuntimeError, "The given v should contain a single Tensor."
        ):
            res = autogradF.jvp(foo, inp, (v, v))

        with self.assertRaisesRegex(
            RuntimeError, "v has invalid size: should be torch.Size"
        ):
            res = autogradF.jvp(foo, inp, v[:2])

        res = autogradF.jvp(foo, inp, v)[1]
        self._assert_same_struct(res, foo(inp))

    @base_and_logging_tensor
    def test_jvp_err_check_strict(self, ctors):
        def foo(a):
            return a.detach()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone()

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.jvp(foo, inp, v, strict=True)
        res = autogradF.jvp(foo, inp, v, strict=False)
        self._assert_same_struct(res[1], res[0])
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "The output of the user-provided function is independent of input 0",
        ):
            res = autogradF.jvp(bar, inp, v, strict=True)
        res = autogradF.jvp(bar, inp, v, strict=False)
        self._assert_same_struct(res[1], res[0])
        self.assertEqual(res[1].abs().sum(), 0.0)

        # The Jacobian does not depend on the input
        def foo(a):
            return a.clone()

        inp.requires_grad_()
        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function is independent of input 0.",
        ):
            res = autogradF.jvp(foo, inp, v, create_graph=True, strict=True)
        res = autogradF.jvp(foo, inp, v, create_graph=True, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1], v)

    @base_and_logging_tensor
    def test_jvp_no_grad(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        with torch.no_grad():
            res = autogradF.jvp(reducer, inputs, v)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

        inputs.requires_grad_()
        v.requires_grad_()
        with torch.no_grad():
            res = autogradF.jvp(reducer, inputs, v, create_graph=True)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

    @base_and_logging_tensor
    def test_jvp_output(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.jvp(reducer, inputs, v)
        self._assert_same_struct(res[1], res[0])
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = (ctors.ones(2), ctors.ones(2))
        out, jvp_val = autogradF.jvp(adder, inputs, v)
        self._assert_same_struct(jvp_val, out)
        self.assertIsNone(out.grad_fn)
        self.assertIsNone(jvp_val[0].grad_fn)
        self.assertIsNone(jvp_val[1].grad_fn)

        def adder(x, y):
            return 2 * x + 3 * y, x + y

        inputs = (ctors.rand(2), ctors.rand(2))
        v = (ctors.tensor([1.0, 0.0]), ctors.tensor([1.0, 0.0]))
        out, jvp_val = autogradF.jvp(adder, inputs, v)
        self._assert_same_struct(jvp_val, out)
        self.assertIsNone(out[0].grad_fn)
        self.assertIsNone(out[1].grad_fn)
        self.assertIsNone(jvp_val[0].grad_fn)
        self.assertIsNone(jvp_val[1].grad_fn)

    @base_and_logging_tensor
    def test_jvp_scalar(self, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.jvp(reducer, inputs, v)
        self._assert_same_struct(res[0], ctors.zeros([]))
        self._assert_same_struct(res[1], res[0])

        def expander(x):
            return x.unsqueeze(0).repeat(4)

        inputs = ctors.rand([])
        v = ctors.ones([])
        res = autogradF.jvp(expander, inputs, v)
        self._assert_same_struct(res[0], ctors.zeros(4))
        self._assert_same_struct(res[1], res[0])

        res = autogradF.jvp(expander, inputs)
        self._assert_same_struct(res[0], ctors.zeros(4))
        self._assert_same_struct(res[1], res[0])

    @base_and_logging_tensor
    def test_jvp_create_graph(self, ctors):
        def reducer(x):
            return x.sum(dim=1)

        inputs = ctors.rand(2, 2, dtype=torch.double)
        v = ctors.ones(2, 2, dtype=torch.double)

        inputs.requires_grad_()
        v.requires_grad_()
        res = autogradF.jvp(reducer, inputs, v, create_graph=True)
        self._assert_same_struct(res[1], res[0])
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )
        gradgradcheck(
            lambda inp, v: autogradF.jvp(reducer, inp, v, create_graph=True),
            (inputs, v),
        )

        def adder(x, y):
            return 2 * x + 3 * y, x * y

        inputs = (
            ctors.rand(2, dtype=torch.double, requires_grad=True),
            ctors.rand(2, dtype=torch.double, requires_grad=True),
        )
        v = (
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
            ctors.tensor([1.0, 0.0], dtype=torch.double, requires_grad=True),
        )

        gradcheck(
            lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )
        gradgradcheck(
            lambda *args: autogradF.jvp(adder, args[:2], args[2:], create_graph=True)[
                1
            ],
            inputs + v,
        )

        def foo(*args):
            x, y = args[:2]
            v = args[2:]

            x = x.cos()
            val, grad = autogradF.jvp(adder, (x, y), v, create_graph=True)

            return (
                val[0].exp()
                + val[1].exp()
                + grad[0].exp()
                + grad[1].exp()
                + x.exp()
                + y.exp()
            )

        gradcheck(foo, inputs + v)
        gradgradcheck(foo, inputs + v)

    def _test_construct_standard_basis_for(self, inputs):
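        # Sanity check on the private helper: for inputs with numels
        # (n1, n2, ...), _construct_standard_basis_for is expected to return one
        # tensor of shape (N, n_i) per input (N = n1 + n2 + ...), matching that
        # input's dtype and device, such that the pieces concatenate along dim=1
        # into the N x N identity (e.g. shapes (5, 2) and (5, 3) cat to eye(5)).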
        numels = tuple(tensor.numel() for tensor in inputs)
        results = autogradF._construct_standard_basis_for(inputs, numels)
        for result, inp in zip(results, inputs):
            self.assertEqual(result.dtype, inp.dtype)
            self.assertEqual(result.device, inp.device)
        results = torch.cat(
            [result.to(device="cpu", dtype=torch.float) for result in results], dim=1
        )
        expected = torch.eye(results[0].shape[0], dtype=torch.float)
        self.assertEqual(results, expected)

    @base_and_logging_tensor
    def test_construct_standard_basis_for(self, ctors):
        test_cases = [
            (ctors.randn(2, 3),),
            (ctors.randn(1),),
            (ctors.randn([]),),
            (ctors.randn(1), ctors.randn([]), ctors.randn([])),
            (ctors.randn(2), ctors.randn(3), ctors.randn([])),
            (ctors.randn(2), ctors.randn([]), ctors.randn(3)),
            (ctors.randn(2, 3), ctors.randn(3), ctors.randn(3, 4, 2)),
            (ctors.randn(2, dtype=torch.float64), ctors.randn(3, dtype=torch.float32)),
        ]

        for inputs in test_cases:
            self._test_construct_standard_basis_for(inputs)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    @base_and_logging_tensor
    def test_construct_standard_basis_for_cuda(self, ctors):
        test_cases = [
            (ctors.randn(2), ctors.randn(3, device="cuda")),
            (ctors.randn(3, device="cuda"), ctors.randn(2)),
        ]

        for inputs in test_cases:
            self._test_construct_standard_basis_for(inputs)

    def _test_vectorize_raises_no_warnings(self, api, ctors):
        # vmap is an experimental prototype. When someone calls torch.vmap,
        # it raises a python warning. This test checks that
        # autogradF.{jacobian, hessian} don't raise that experimental prototype
        # warning; it is not nice for a public-facing API to raise a warning
        # no matter how it is called.
        def foo(a):
            return (a**2).sum()

        x = ctors.randn(3)
        with warnings.catch_warnings(record=True) as wa:
            result = api(foo, x, vectorize=True)
        self.assertEqual(len(wa), 0)

    @base_and_logging_tensor
    def test_jacobian_vectorize_raises_no_warnings(self, ctors):
        return self._test_vectorize_raises_no_warnings(autogradF.jacobian, ctors)

    @base_and_logging_tensor
    def test_hessian_vectorize_raises_no_warnings(self, ctors):
        return self._test_vectorize_raises_no_warnings(autogradF.hessian, ctors)

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_jacobian_err_check(self, vectorize, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3)

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        inp = ctors.rand(4)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to jacobian must be either a Tensor"
        ):
            res = autogradF.jacobian(foo, (inp, 2), vectorize=vectorize)

        with self.assertRaisesRegex(
            TypeError,
            "The outputs of the user-provided function given to jacobian must",
        ):
            res = autogradF.jacobian(bar, inp, vectorize=vectorize)

        res = autogradF.jacobian(foo, inp, vectorize=vectorize)
        self._assert_interleaved_struct(res, foo(inp), inp)

        def foo(a, b):
            return b, 3 * a.narrow(0, 0, 3)

        inp = (ctors.rand(4), ctors.rand(5))

        res = autogradF.jacobian(foo, inp, vectorize=vectorize)
        self._assert_interleaved_struct(res, foo(*inp), inp)

    @base_and_logging_tensor
    def test_jacobian_err_check_strict(self, ctors):
        def foo(a):
            return a.detach()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone()

        inp = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.jacobian(foo, inp, strict=True)
        res = autogradF.jacobian(foo, inp, strict=False)
        self._assert_interleaved_struct(res, foo(inp), inp)
        self.assertEqual(res.abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function is independent of input 0.",
        ):
            res = autogradF.jacobian(bar, inp, strict=True)
        res = autogradF.jacobian(bar, inp, strict=False)
        self._assert_interleaved_struct(res, foo(inp), inp)
        self.assertEqual(res.abs().sum(), 0.0)

        # The Jacobian does not depend on the input
        def foo(a):
            return a.clone()

        inp.requires_grad_()
        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function is independent of input 0.",
        ):
            res = autogradF.jacobian(foo, inp, create_graph=True, strict=True)
        res = autogradF.jacobian(foo, inp, create_graph=True, strict=False)
        self._assert_interleaved_struct(res, inp, inp)
        self.assertEqual(res, torch.eye(4))

    @base_and_logging_tensor
    def test_jacobian_err_check_strict_vectorize(self, ctors):
        def foo(x):
            return x

        inp = ctors.rand(4)
        with self.assertRaisesRegex(RuntimeError, "not supported together"):
            res = autogradF.jacobian(foo, inp, strict=True, vectorize=True)

    @base_and_logging_tensor
    def test_jacobian_no_grad(self, ctors):
        def exp_reducer(x):
            return x.exp().sum(dim=1)

        inputs = ctors.rand(4, 4)
        with torch.no_grad():
            res = autogradF.jacobian(exp_reducer, inputs)
        self.assertIsNone(res.grad_fn)
        self.assertNotEqual(res, ctors.zeros(4, 4))

        with torch.no_grad():
            res = autogradF.jacobian(exp_reducer, inputs, create_graph=True)
        self.assertIsNotNone(res.grad_fn)
        self.assertNotEqual(res, ctors.zeros(4, 4))

    @vectorized_logging_tensor
    def test_jacobian_output(self, vectorize, ctors):
        def exp_reducer(x):
            return x.exp().sum(dim=1)

        inputs = ctors.rand(4, 4)
        res = autogradF.jacobian(exp_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
        self.assertIsNone(res.grad_fn)

        def identity(x):
            return x.clone()

        inputs = ctors.rand(4)
        res = autogradF.jacobian(identity, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, identity(inputs), inputs)
        self.assertIsNone(res.grad_fn)
        self.assertEqual(res, torch.eye(4))

        def add_exp_reducer(x, y):
            return (x + y.exp()).sum(dim=1)

        inputs = (ctors.rand(4, 4), ctors.rand(4, 4))
        res = autogradF.jacobian(add_exp_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

    @vectorized_logging_tensor
    def test_jacobian_scalar(self, vectorize, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        res = autogradF.jacobian(reducer, inputs, vectorize=vectorize)
        self._assert_same_struct(res, inputs)

        def expander(x):
            return x.unsqueeze(0).repeat(4)

        inputs = ctors.rand([])
        res = autogradF.jacobian(expander, inputs, vectorize=vectorize)
        self._assert_same_struct(res, ctors.zeros(4))

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_jacobian_create_graph(self, vectorize, ctors):
        def exp_reducer(x):
            return x.exp().sum(dim=1)

        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
        res = autogradF.jacobian(
            exp_reducer, inputs, create_graph=True, vectorize=vectorize
        )
        self._assert_interleaved_struct(res, exp_reducer(inputs), inputs)
        self.assertIsNotNone(res.grad_fn)

        gradcheck(
            lambda inp: autogradF.jacobian(
                exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )
        gradgradcheck(
            lambda inp: autogradF.jacobian(
                exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )

        def add_exp_reducer(x, y):
            return (x + y).exp().sum(dim=1)

        inputs = (
            ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
            ctors.rand(4, 4, dtype=torch.double, requires_grad=True),
        )
        res = autogradF.jacobian(
            add_exp_reducer, inputs, create_graph=True, vectorize=vectorize
        )
        self._assert_interleaved_struct(res, add_exp_reducer(*inputs), inputs)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda *inp: autogradF.jacobian(
                add_exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )
        gradgradcheck(
            lambda *inp: autogradF.jacobian(
                add_exp_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )

        def foo(x, y):
            x = x.cos()
            val, jac = autogradF.jacobian(
                add_exp_reducer, (x, y), create_graph=True, vectorize=vectorize
            )

            res = val[0].exp().sum() + val[1].exp().sum() + jac[0].exp().sum()
            res = res + jac[1].exp().sum() + x.exp().sum() + y.exp().sum()
            return res

        gradcheck(foo, inputs)
        gradgradcheck(foo, inputs)

    def _check_jacobian_vectorize_correctness(self, f, inputs, test_forward_ad=True):
        expected = autogradF.jacobian(f, inputs, vectorize=False)
        result_backward_mode = autogradF.jacobian(f, inputs, vectorize=True)
        self.assertEqual(result_backward_mode, expected)

        if test_forward_ad:
            result_forward_mode = autogradF.jacobian(
                f, inputs, strategy="forward-mode", vectorize=True
            )
            self.assertEqual(result_forward_mode, expected)

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_simple(self, ctors):
        def f(x):
            return 3 * x**2

        x = ctors.randn(2, 3, 5)
        self._check_jacobian_vectorize_correctness(f, x)

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_multi_input(self, ctors):
        def f(x, y):
            return (x.cos() * x) @ y.sin()

        x = ctors.randn(2, 3)
        y = ctors.randn(3, 5)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_multi_input_multi_output(self, ctors):
        def f(x, y):
            return (x * x) @ y, x @ (x.sum(1) * y), y.sum()

        x = ctors.randn(5, 3)
        y = ctors.randn(3, 5)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_unrelated_outputs(self, ctors):
        def f(x, y):
            return x, y, x, y

        x = ctors.randn(2)
        y = ctors.randn(3)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_zero_dim(self, ctors):
        # zero-dim output
        def f(x, y):
            return x.sum(), y.sum(), x * y

        x = ctors.randn(3)
        y = ctors.randn(3)
        self._check_jacobian_vectorize_correctness(f, (x, y))

        # zero-dim input
        def g(x):
            return torch.stack([x, x, x])

        x = ctors.randn([])
        self._check_jacobian_vectorize_correctness(g, x)

        # Mixed zero-dim input / zero-dim output
        def h(x, y):
            return y.sum(), x * y

        x = ctors.randn([])
        y = ctors.randn(1)
        self._check_jacobian_vectorize_correctness(h, (x, y))

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_different_devices(self, ctors):
        def f(x, y):
            return x * y, (x * y).cuda()

        x = ctors.randn(3)
        y = ctors.randn(3)
        self._check_jacobian_vectorize_correctness(f, (x, y))

    @base_and_logging_tensor
    def test_jacobian_vectorize_correctness_different_dtype(self, ctors):
        def f(x, y):
            return (x * y).float(), (x * y).double()

        x = ctors.randn(3)
        y = ctors.randn(3)
        # The Jacobian computed using forward AD has the dtype of the output,
        # but the Jacobian computed with reverse AD has the dtype of the input.
        self._check_jacobian_vectorize_correctness(f, (x, y), test_forward_ad=False)

    def _check_hessian_vectorize_correctness(self, f, inputs):
        expected = autogradF.hessian(f, inputs, vectorize=False)
        result = autogradF.hessian(f, inputs, vectorize=True)
        self.assertEqual(result, expected)

        result_forward_mode = autogradF.hessian(
            f, inputs, outer_jacobian_strategy="forward-mode", vectorize=True
        )
        self.assertEqual(result_forward_mode, expected)

    @base_and_logging_tensor
    def test_hessian_vectorize_correctness_simple(self, ctors):
        def f(x):
            return (3 * x**2).sum()

        x = ctors.randn(2, 3, 5)
        self._check_hessian_vectorize_correctness(f, x)

    @base_and_logging_tensor
    def test_hessian_vectorize_correctness_multi_input(self, ctors):
        def f(x, y, z):
            return ((x.relu() * x) @ y.sin() @ z).sum()

        x = ctors.randn(2, 3)
        y = ctors.randn(3, 5)
        z = ctors.randn(5, 5)
        self._check_hessian_vectorize_correctness(f, (x, y, z))

    @base_and_logging_tensor
    def test_hessian_vectorize_correctness_unrelated_outputs(self, ctors):
        # output unrelated to one input
        def f(x, y):
            return (x**2).sum()

        x = ctors.randn(2)
        y = ctors.randn(3)
        self._check_hessian_vectorize_correctness(f, (x, y))

        # output unrelated to all inputs
        def f(x, y):
            return ctors.ones([])

        x = ctors.randn(2)
        y = ctors.randn(3)
        self._check_hessian_vectorize_correctness(f, (x, y))

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_hessian_err_check(self, vectorize, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        def bar2(a):
            return 3 * a.narrow(0, 0, 3)

        def bar3(a):
            return 3 * a.narrow(0, 0, 3), 3 * a.narrow(0, 0, 3)

        inp = ctors.rand(4)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to hessian must be either a Tensor"
        ):
            res = autogradF.hessian(foo, (inp, 2), vectorize=vectorize)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to hessian must"
        ):
            res = autogradF.hessian(bar, inp, vectorize=vectorize)

        err_msg_out = "The Tensor returned by the function given to hessian should contain a single element"
        with self.assertRaisesRegex(RuntimeError, err_msg_out):
            res = autogradF.hessian(bar2, inp, vectorize=vectorize)

        with self.assertRaisesRegex(
            RuntimeError, "The function given to hessian should return a single Tensor"
        ):
            res = autogradF.hessian(bar3, inp, vectorize=vectorize)

        res = autogradF.hessian(foo, inp, vectorize=vectorize)
        self._assert_interleaved_struct(res, inp, inp)

        def foo(a, b):
            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()

        inp = (ctors.rand(4), ctors.rand(5))

        res = autogradF.hessian(foo, inp, vectorize=vectorize)
        self._assert_interleaved_struct(res, inp, inp)

    @base_and_logging_tensor
    def test_hessian_err_check_strict(self, ctors):
        def foo(a):
            return a.detach().sum()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone().sum()

        def bar2(a):
            # A Linear function for which the jacobian is independent of the input
            return (3 * a).sum()

        inp = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.hessian(foo, inp, strict=True)
        res = autogradF.hessian(foo, inp, strict=False)
        self._assert_interleaved_struct(res, inp, inp)
        self.assertEqual(res.abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function with respect to input 0",
        ):
            res = autogradF.hessian(bar, inp, strict=True)
        res = autogradF.hessian(bar, inp, strict=False)
        self._assert_interleaved_struct(res, inp, inp)
        self.assertEqual(res.abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function with respect to input 0 is",
        ):
            res = autogradF.hessian(bar2, inp, strict=True)
        res = autogradF.hessian(bar2, inp, strict=False)
        self._assert_interleaved_struct(res, inp, inp)
        self.assertEqual(res.abs().sum(), 0.0)

    @base_and_logging_tensor
    def test_hessian_err_check_strict_vectorize(self, ctors):
        def foo(x):
            return (x**3).sum()

        inp = ctors.rand(4)
        with self.assertRaisesRegex(RuntimeError, "not supported together"):
            res = autogradF.hessian(foo, inp, strict=True, vectorize=True)

    @base_and_logging_tensor
    def test_hessian_no_grad(self, ctors):
        def pow_reducer(x):
            return x.pow(3).sum()

        inputs = ctors.rand(2, 2)
        with torch.no_grad():
            res = autogradF.hessian(pow_reducer, inputs)
        self.assertIsNone(res[0][0].grad_fn)
        self.assertIsNone(res[0][1].grad_fn)
        self.assertIsNone(res[1][0].grad_fn)
        self.assertIsNone(res[1][1].grad_fn)
        self.assertNotEqual(res, ctors.zeros(2, 2, 2))

        with torch.no_grad():
            res = autogradF.hessian(pow_reducer, inputs, create_graph=True)
        self.assertIsNotNone(res[0][0].grad_fn)
        self.assertIsNotNone(res[0][1].grad_fn)
        self.assertIsNotNone(res[1][0].grad_fn)
        self.assertIsNotNone(res[1][1].grad_fn)
        self.assertNotEqual(res, ctors.zeros(2, 2, 2))

    @vectorized_logging_tensor
    def test_hessian_output(self, vectorize, ctors):
        def pow_reducer(x):
            return x.pow(3).sum()

        inputs = ctors.rand(2, 2)
        res = autogradF.hessian(pow_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, inputs, inputs)
        self.assertIsNone(res.grad_fn)

        def add_pow_reducer(x, y):
            return (x + y).pow(3).sum()

        inputs = (ctors.rand(2, 2), ctors.rand(2, 2))
        res = autogradF.hessian(add_pow_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, inputs, inputs)
        self.assertIsNone(res[0][0].grad_fn)
        self.assertIsNone(res[0][1].grad_fn)
        self.assertIsNone(res[1][0].grad_fn)
        self.assertIsNone(res[1][1].grad_fn)

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_hessian_scalar(self, vectorize, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        res = autogradF.hessian(reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, inputs, inputs)

        inputs = ctors.rand([])
        res = autogradF.hessian(reducer, inputs, vectorize=vectorize)
        self._assert_same_struct(res, inputs)

        def bad_reducer(x):
            return x.sum().view(1, 1, 1)

        inputs = ctors.rand(4, 4)
        res = autogradF.hessian(bad_reducer, inputs, vectorize=vectorize)
        self._assert_interleaved_struct(res, inputs, inputs)

    @parametrize("vectorize", [True, False])
    @base_and_logging_tensor
    def test_hessian_create_graph(self, vectorize, ctors):
        def pow_reducer(x):
            return x.pow(3).sum()

        inputs = ctors.rand(2, 2, dtype=torch.double, requires_grad=True)
        res = autogradF.hessian(
            pow_reducer, inputs, create_graph=True, vectorize=vectorize
        )
        self._assert_interleaved_struct(res, inputs, inputs)
        self.assertIsNotNone(res.grad_fn)

        gradcheck(
            lambda inp: autogradF.hessian(
                pow_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )
        gradgradcheck(
            lambda inp: autogradF.hessian(
                pow_reducer, inp, create_graph=True, vectorize=vectorize
            ),
            inputs,
        )

        def add_pow_reducer(x, y):
            return (x + y).pow(3).sum()

        inputs = (
            ctors.rand(2, 2, dtype=torch.double, requires_grad=True),
            ctors.rand(2, 2, dtype=torch.double, requires_grad=True),
        )
        res = autogradF.hessian(
            add_pow_reducer, inputs, create_graph=True, vectorize=vectorize
        )
        self._assert_interleaved_struct(res, inputs, inputs)
        self.assertIsNotNone(res[0][0].grad_fn)
        self.assertIsNotNone(res[0][1].grad_fn)
        self.assertIsNotNone(res[1][0].grad_fn)
        self.assertIsNotNone(res[1][1].grad_fn)

        def flatten(inp):
            return tuple(el_lvl2 for el_lvl1 in inp for el_lvl2 in el_lvl1)

        gradcheck(
            lambda *inp: flatten(
                autogradF.hessian(
                    add_pow_reducer, inp, create_graph=True, vectorize=vectorize
                )
            ),
            inputs,
        )
        gradgradcheck(
            lambda *inp: flatten(
                autogradF.hessian(
                    add_pow_reducer, inp, create_graph=True, vectorize=vectorize
                )
            ),
            inputs,
        )

        def foo(x, y):
            x = x.cos()
            val, hess = autogradF.hessian(
                add_pow_reducer, (x, y), create_graph=True, vectorize=vectorize
            )

            res = val[0].cos().sum() + val[1].cos().sum() + hess[0].cos().sum()
            res = res + hess[1].cos().sum() + x.cos().sum() + y.cos().sum()
            return res

        gradcheck(foo, inputs)
        gradgradcheck(foo, inputs)

    @base_and_logging_tensor
    def test_vhp_err_check(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        def bar2(a):
            return 3 * a.narrow(0, 0, 3)

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to vhp must be either a Tensor"
        ):
            res = autogradF.vhp(foo, (inp, 2), v)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to vhp must"
        ):
            res = autogradF.vhp(bar, inp, v)

        err_msg_out = "The Tensor returned by the function given to vhp should contain a single element"
        with self.assertRaisesRegex(RuntimeError, err_msg_out):
            res = autogradF.vhp(bar2, inp, v)

        with self.assertRaisesRegex(RuntimeError, "v has invalid size:"):
            res = autogradF.vhp(foo, inp, ctors.rand(5))

        with self.assertRaisesRegex(
            TypeError,
            "The v given to vhp must be either a Tensor or a tuple of Tensors",
        ):
            res = autogradF.vhp(foo, inp, (v, 2))

        res = autogradF.vhp(foo, inp, v)
        self._assert_same_struct(res[1], inp)

        def foo(a, b):
            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()

        inp = (ctors.rand(4), ctors.rand(5))
        v = (ctors.rand(4), ctors.rand(5))

        res = autogradF.vhp(foo, inp, v)
        self._assert_same_struct(res[1], inp)

    @base_and_logging_tensor
    def test_vhp_err_check_strict(self, ctors):
        def foo(a):
            return a.detach().sum()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone().sum()

        def bar2(a):
            # A Linear function for which the jacobian is independent of the input
            return (3 * a).sum()

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.vhp(foo, inp, v, strict=True)
        res = autogradF.vhp(foo, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "The output of the user-provided function is independent of input 0",
        ):
            res = autogradF.vhp(bar, inp, v, strict=True)
        res = autogradF.vhp(bar, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function with respect to input 0 is",
        ):
            res = autogradF.vhp(bar2, inp, v, strict=True)
        res = autogradF.vhp(bar2, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

    @base_and_logging_tensor
    def test_vhp_no_grad(self, ctors):
        def reducer(x):
            return x.exp().sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        with torch.no_grad():
            res = autogradF.vhp(reducer, inputs, v)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

        with torch.no_grad():
            res = autogradF.vhp(reducer, inputs, v, create_graph=True)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

    @base_and_logging_tensor
    def test_vhp_output(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.vhp(foo, inputs, v)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

        def bar(a, b):
            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

        inputs = (ctors.rand(3), ctors.rand(4))
        v = (ctors.ones(3), ctors.ones(4))
        out, vhp_val = autogradF.vhp(bar, inputs, v)
        self._assert_same_struct(vhp_val, inputs)
        self.assertIsNone(out.grad_fn)
        self.assertIsNone(vhp_val[0].grad_fn)
        self.assertIsNone(vhp_val[1].grad_fn)

    @base_and_logging_tensor
    def test_vhp_scalar(self, ctors):
        def reducer(x):
            return x.sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.vhp(reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)

        inputs = ctors.rand([])
        v = ctors.rand([])
        res = autogradF.vhp(reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)

        res = autogradF.vhp(reducer, inputs)
        self._assert_same_struct(res[1], inputs)

        def bad_reducer(x):
            return x.sum().view(1, 1, 1)

        inputs = ctors.rand(4, 4)
        v = ctors.rand(4, 4)
        res = autogradF.vhp(bad_reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)

    @base_and_logging_tensor
    def test_vhp_create_graph(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
        v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True)
        res = autogradF.vhp(foo, inputs, v, create_graph=True)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda inp, v: autogradF.vhp(foo, inp, v, create_graph=True), (inputs, v)
        )
        gradgradcheck(
            lambda inp, v: autogradF.vhp(foo, inp, v, create_graph=True), (inputs, v)
        )

        def bar(a, b):
            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

        inputs = (
            ctors.rand(3, dtype=torch.double, requires_grad=True),
            ctors.rand(4, dtype=torch.double, requires_grad=True),
        )
        v = (
            ctors.ones(3, dtype=torch.double, requires_grad=True),
            ctors.ones(4, dtype=torch.double, requires_grad=True),
        )
        out, vhp_val = autogradF.vhp(bar, inputs, v, create_graph=True)
        self._assert_same_struct(vhp_val, inputs)
        self.assertIsNotNone(out.grad_fn)
        self.assertIsNotNone(vhp_val[0].grad_fn)
        self.assertIsNotNone(vhp_val[1].grad_fn)

        gradcheck(
            lambda *args: autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1],
            inputs + v,
        )
        gradgradcheck(
            lambda *args: autogradF.vhp(bar, args[:2], args[2:], create_graph=True)[1],
            inputs + v,
        )

        def foo(*args):
            x, y = args[:2]
            v = args[2:]

            x = x.cos()
            val, grad = autogradF.vhp(bar, (x, y), v, create_graph=True)

            return (
                val.cos()
                + grad[0].cos().sum()
                + grad[1].cos()
                + x.cos().sum()
                + y.cos()
            )

        gradcheck(foo, inputs + v)
        gradgradcheck(foo, inputs + v)

    @base_and_logging_tensor
    def test_hvp_err_check(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        def bar(a):
            return 3 * a.narrow(0, 0, 3), "bar"

        def bar2(a):
            return 3 * a.narrow(0, 0, 3)

        inp = ctors.rand(4)
        v = ctors.rand(4)
        res = autogradF.hvp(foo, inp, v)
        with self.assertRaisesRegex(
            TypeError, "The inputs given to hvp must be either a Tensor"
        ):
            res = autogradF.hvp(foo, (inp, 2), v)

        with self.assertRaisesRegex(
            TypeError, "The outputs of the user-provided function given to hvp must"
        ):
            res = autogradF.hvp(bar, inp, v)

        err_msg_out = "The Tensor returned by the function given to hvp should contain a single element"
        with self.assertRaisesRegex(RuntimeError, err_msg_out):
            res = autogradF.hvp(bar2, inp, v)

        with self.assertRaisesRegex(RuntimeError, "v has invalid size:"):
            res = autogradF.hvp(foo, inp, ctors.rand(5))

        with self.assertRaisesRegex(
            TypeError,
            "The v given to hvp must be either a Tensor or a tuple of Tensors",
        ):
            res = autogradF.hvp(foo, inp, (v, 2))

        res = autogradF.hvp(foo, inp, v)
        self._assert_same_struct(res[1], inp)

        def foo(a, b):
            return (3 * b.narrow(0, 0, 3) * a.narrow(0, 0, 3)).sum()

        inp = (ctors.rand(4), ctors.rand(5))
        v = (ctors.rand(4), ctors.rand(5))

        res = autogradF.hvp(foo, inp, v)
        self._assert_same_struct(res[1], inp)

    @base_and_logging_tensor
    def test_hvp_err_check_strict(self, ctors):
        def foo(a):
            return a.detach().sum()

        def bar(a):
            # Make a non-leaf Tensor that requires_grad but that is not connected to the input
            return a.long().float().requires_grad_().clone().sum()

        def bar2(a):
            # A Linear function for which the jacobian is independent of the input
            return (3 * a).sum()

        inp = ctors.rand(4)
        v = ctors.rand(4)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of the user-provided function does not require gradients.",
        ):
            res = autogradF.hvp(foo, inp, v, strict=True)
        res = autogradF.hvp(foo, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "The output of the user-provided function is independent of input 0",
        ):
            res = autogradF.hvp(bar, inp, v, strict=True)
        res = autogradF.hvp(bar, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

        with self.assertRaisesRegex(
            RuntimeError,
            "jacobian of the user-provided function with respect to input 0 is",
        ):
            res = autogradF.hvp(bar2, inp, v, strict=True)
        res = autogradF.hvp(bar2, inp, v, strict=False)
        self._assert_same_struct(res[1], inp)
        self.assertEqual(res[1].abs().sum(), 0.0)

    @base_and_logging_tensor
    def test_hvp_no_grad(self, ctors):
        def reducer(x):
            return x.exp().sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        with torch.no_grad():
            res = autogradF.hvp(reducer, inputs, v)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

        with torch.no_grad():
            res = autogradF.hvp(reducer, inputs, v, create_graph=True)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)
        self.assertNotEqual(res[1], ctors.zeros(4, 4))

    @base_and_logging_tensor
    def test_hvp_output(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.hvp(foo, inputs, v)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNone(res[0].grad_fn)
        self.assertIsNone(res[1].grad_fn)

        def bar(a, b):
            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

        inputs = (ctors.rand(3), ctors.rand(4))
        v = (ctors.ones(3), ctors.ones(4))
        out, hvp_val = autogradF.hvp(bar, inputs, v)
        self._assert_same_struct(hvp_val, inputs)
        self.assertIsNone(out.grad_fn)
        self.assertIsNone(hvp_val[0].grad_fn)
        self.assertIsNone(hvp_val[1].grad_fn)

    @base_and_logging_tensor
    def test_hvp_scalar(self, ctors):
        def reducer(x):
            return x.exp().sum()

        inputs = ctors.rand(4, 4)
        v = ctors.ones(4, 4)
        res = autogradF.hvp(reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)

        inputs = ctors.rand([])
        v = ctors.rand([])
        res = autogradF.hvp(reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)

        res = autogradF.hvp(reducer, inputs)
        self._assert_same_struct(res[1], inputs)

        def bad_reducer(x):
            return x.exp().sum().view(1, 1, 1)

        inputs = ctors.rand(4, 4)
        v = ctors.rand(4, 4)
        res = autogradF.hvp(bad_reducer, inputs, v)
        self._assert_same_struct(res[1], inputs)

    @base_and_logging_tensor
    def test_hvp_create_graph(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        inputs = ctors.rand(4, 4, dtype=torch.double, requires_grad=True)
        v = ctors.ones(4, 4, dtype=torch.double, requires_grad=True)
        res = autogradF.hvp(foo, inputs, v, create_graph=True)
        self._assert_same_struct(res[1], inputs)
        self.assertIsNotNone(res[0].grad_fn)
        self.assertIsNotNone(res[1].grad_fn)

        gradcheck(
            lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v)
        )
        gradgradcheck(
            lambda inp, v: autogradF.hvp(foo, inp, v, create_graph=True), (inputs, v)
        )

        def bar(a, b):
            return (a + 3 * b.narrow(0, 0, 3)).exp().sum()

        inputs = (
            ctors.rand(3, dtype=torch.double, requires_grad=True),
            ctors.rand(4, dtype=torch.double, requires_grad=True),
        )
        v = (
            ctors.ones(3, dtype=torch.double, requires_grad=True),
            ctors.ones(4, dtype=torch.double, requires_grad=True),
        )
        out, hvp_val = autogradF.hvp(bar, inputs, v, create_graph=True)
        self._assert_same_struct(hvp_val, inputs)
        self.assertIsNotNone(out.grad_fn)
        self.assertIsNotNone(hvp_val[0].grad_fn)
        self.assertIsNotNone(hvp_val[1].grad_fn)

        gradcheck(
            lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1],
            inputs + v,
        )
        gradgradcheck(
            lambda *args: autogradF.hvp(bar, args[:2], args[2:], create_graph=True)[1],
            inputs + v,
        )

        def foo(*args):
            x, y = args[:2]
            v = args[2:]

            x = x.cos()
            val, grad = autogradF.hvp(bar, (x, y), v, create_graph=True)

            return (
                val.cos()
                + grad[0].cos().sum()
                + grad[1].cos()
                + x.cos().sum()
                + y.cos()
            )

        gradcheck(foo, inputs + v)
        gradgradcheck(foo, inputs + v)

    @base_and_logging_tensor
    def test_jacobian_match_vjp_jvp(self, ctors):
        def foo(x):
            return x**3 + x.sum()

        inputs = ctors.rand(4)
        v = ctors.rand(4)

        jac = autogradF.jacobian(foo, inputs)
        jvp = autogradF.jvp(foo, inputs, v)[1]
        vjp = autogradF.vjp(foo, inputs, v)[1]

        self.assertEqual(jvp, torch.mm(jac, v.unsqueeze(1)).squeeze(1))
        self.assertEqual(vjp, torch.mm(v.unsqueeze(0), jac).squeeze(0))

    @base_and_logging_tensor
    def test_hessian_match_vhp_hvp(self, ctors):
        def foo(a):
            return 3 * a.narrow(0, 0, 3).exp().sum()

        inputs = ctors.rand(4)
        v = ctors.rand(4)

        hes = autogradF.hessian(foo, inputs)
        hvp = autogradF.hvp(foo, inputs, v)[1]
        vhp = autogradF.vhp(foo, inputs, v)[1]

        self.assertEqual(hvp, torch.mm(hes, v.unsqueeze(1)).squeeze(1))
        self.assertEqual(vhp, torch.mm(v.unsqueeze(0), hes).squeeze(0))


instantiate_parametrized_tests(TestAutogradFunctional)

if __name__ == "__main__":
    run_tests()