# Owner(s): ["module: autograd"]

import collections
import contextlib
import functools
import gc
import io
import math
import operator
import os
import pickle
import random
import subprocess
import sys
import tempfile
import threading
import time
import unittest
import uuid
import warnings
import weakref
from collections import OrderedDict
from copy import deepcopy
from functools import partial, reduce
from itertools import product
from operator import mul
from typing import List, Tuple, TYPE_CHECKING

import torch
import torch.autograd._functions
import torch.autograd.forward_ad as fwAD
from torch import inf, nan, nn
from torch.autograd import (
    _calculate_shape,
    detect_anomaly,
    Function,
    kineto_available,
    Variable,
)
from torch.autograd.function import InplaceFunction, once_differentiable
from torch.autograd.graph import GradientEdge
from torch.autograd.profiler import emit_itt, emit_nvtx, profile, record_function
from torch.autograd.profiler_util import (
    _format_time,
    EventList,
    FunctionEvent,
    FunctionEventAvg,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    dtypes,
    dtypesIfCUDA,
    dtypesIfMPS,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_methods_invocations import mask_not_all_zeros
from torch.testing._internal.common_utils import (
    disable_gc,
    gradcheck,
    gradgradcheck,
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    run_tests,
    set_warn_always_context,
    skipIfMps,
    skipIfNoLapack,
    skipIfTorchDynamo,
    slowTest,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import (
    checkpoint,
    checkpoint_sequential,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)
from torch.utils.cpp_extension import load_inline
from torch.utils.flop_counter import FlopCounterMode


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle


def graph_desc(fn):
    if fn is None:
        return "None"
    result = type(fn).__name__ + "("
    next_functions = fn.next_functions
    for next_fn, _ in next_functions:
        result += graph_desc(next_fn)
        result += ", "
    if next_functions:
        result = result[:-2]
    return result + ")"


class TestAutograd(TestCase):
    def test_copy_slices_graph_task_updates(self):
        def f1(x, y):
            out = x.clone().view(-1)
            out += y
            return out

        def f2(x, y):
            out = x.clone().view(-1)
            b = out * 2
            out += y
            return out + b

        x = torch.rand(2, requires_grad=True)
        y = torch.rand(2, requires_grad=True)

        y_safe = torch._C._functions.DelayedError("Boom!", 1)(y)

        for f in [f1, f2]:
            # Ensure that the error Node works
            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                out.sum().backward()

            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                torch.autograd.grad(out.sum(), y)

            # Ensure that if we don't ask for y, it doesn't crash
            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), x)

            out = f(x, y_safe)
            torch.autograd.grad(out.sum(),
y_safe) 143 144 out = f(x, y_safe) 145 torch.autograd.grad(out.sum(), (x, y_safe)) 146 147 # Ensure that we don't run extra view Node 148 def f3(x, y): 149 out = x.clone().view(-1) 150 151 def hook(*args): 152 # This should never be called! 153 self.assertTrue(False) 154 155 out.register_hook(hook) 156 157 b = out + y 158 out += y 159 return out + b, b 160 161 out, b = f3(x, y_safe) 162 torch.autograd.grad(out.sum(), (b, y_safe)) 163 164 def test_grad_mode_class_decoration(self): 165 # Decorating class is deprecated and should not be used 166 with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"): 167 168 @torch.no_grad() 169 class Foo: 170 def __init__(self) -> None: 171 assert not torch.is_grad_enabled() 172 173 def foo(self): 174 # Not applied to methods 175 assert torch.is_grad_enabled() 176 177 # Show that we can actually construct the class 178 foo = Foo() 179 foo.foo() 180 181 # Decorating functions or methods is fine though 182 with warnings.catch_warnings(record=True) as w: 183 184 @torch.no_grad() 185 def foo(): 186 assert not torch.is_grad_enabled() 187 188 foo() 189 190 class Foo2: 191 @torch.no_grad() 192 def __init__(self) -> None: 193 assert not torch.is_grad_enabled() 194 195 @torch.no_grad() 196 def foo(self): 197 assert not torch.is_grad_enabled() 198 199 foo2 = Foo2() 200 foo2.foo() 201 202 self.assertEqual(len(w), 0) 203 204 def test_tensor_grad_warnings(self): 205 dummy = torch.empty(1) 206 207 with warnings.catch_warnings(record=True) as w: 208 # Accessing .grad on leaf 209 dummy.requires_grad_() 210 foo = dummy.grad 211 self.assertEqual(len(w), 0) 212 213 # Accessing .grad on non-leaf 214 dummy = dummy.clone() 215 foo = dummy.grad 216 self.assertEqual(len(w), 1) 217 218 # Accessing .grad on non-leaf that retains gradients 219 dummy.retain_grad() 220 foo = dummy.grad 221 self.assertEqual(len(w), 1) 222 223 def _function_test(self, cls): 224 x = torch.randn(5, 5, requires_grad=True) 225 y = torch.randn(5, 5, requires_grad=True) 226 result = cls.apply(x, 2, y) 227 go = torch.ones((), requires_grad=True) 228 result.sum().backward(go, create_graph=True) 229 230 self.assertEqual(x.grad, y + torch.ones(5, 5)) 231 self.assertEqual(y.grad, x + torch.ones(5, 5) * 2) 232 self.assertIsNotNone(x.grad.grad_fn) 233 self.assertIsNotNone(y.grad.grad_fn) 234 235 return x, y 236 237 def test_function(self): 238 class MyFunction(Function): 239 @staticmethod 240 def forward(ctx, tensor1, pyscalar, tensor2): 241 ctx.pyscalar = pyscalar 242 ctx.save_for_backward(tensor1, tensor2) 243 return tensor1 + pyscalar * tensor2 + tensor1 * tensor2 244 245 @staticmethod 246 def backward(ctx, grad_output): 247 var1, var2 = ctx.saved_tensors 248 # NOTE: self is the test case here 249 self.assertIsInstance(var1, torch.Tensor) 250 self.assertIsInstance(var2, torch.Tensor) 251 self.assertIsInstance(grad_output, torch.Tensor) 252 return ( 253 grad_output + grad_output * var2, 254 None, 255 grad_output * ctx.pyscalar + grad_output * var1, 256 ) 257 258 x, y = self._function_test(MyFunction) 259 260 x_grad_desc = graph_desc(x.grad.grad_fn) 261 y_grad_desc = graph_desc(y.grad.grad_fn) 262 self.assertExpected(x_grad_desc, "x_grad_desc") 263 self.assertExpected(y_grad_desc, "y_grad_desc") 264 265 def test_once_differentiable(self): 266 class MyFunction(Function): 267 @staticmethod 268 def forward(ctx, tensor1, pyscalar, tensor2): 269 ctx.pyscalar = pyscalar 270 ctx.save_for_backward(tensor1, tensor2) 271 return tensor1 + pyscalar * tensor2 + tensor1 * tensor2 272 273 @staticmethod 274 
@once_differentiable 275 def backward(ctx, grad_output): 276 self.assertFalse(torch.is_grad_enabled()) 277 t1, t2 = ctx.saved_tensors 278 return ( 279 grad_output + grad_output * t2, 280 None, 281 grad_output * ctx.pyscalar + grad_output * t1, 282 ) 283 284 x, y = self._function_test(MyFunction) 285 self.assertEqual( 286 graph_desc(x.grad.grad_fn), 287 "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))", 288 ) 289 self.assertEqual( 290 graph_desc(y.grad.grad_fn), 291 "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))", 292 ) 293 294 def test_function_returns_input(self): 295 class MyFunction(Function): 296 @staticmethod 297 def forward(ctx, x): 298 return x 299 300 @staticmethod 301 def backward(ctx, grad): 302 return grad * 2 303 304 for shape in [(1,), ()]: 305 v = torch.ones(shape, requires_grad=True) 306 MyFunction.apply(v).backward() 307 self.assertEqual(v.grad, torch.full(shape, 2.0)) 308 309 with torch.no_grad(): 310 v.grad.zero_() 311 MyFunction.apply(v.clone()).backward() 312 self.assertEqual(v.grad, torch.full(shape, 2.0)) 313 314 def test_function_returns_undefined_tensor(self): 315 class MyFunction(Function): 316 @staticmethod 317 def forward(ctx, x): 318 return x * 2 319 320 @staticmethod 321 def backward(ctx, grad): 322 return None 323 324 # Test that undefined tensors returned from custom backward function 325 # are propagated as undefined and not tensor full of zeroes 326 x = torch.ones(1, requires_grad=True) 327 328 MyFunction.apply(x).backward() 329 self.assertIsNone(x.grad) 330 331 MyFunction.apply(x**2).backward() 332 self.assertIsNone(x.grad) 333 334 MyFunction.apply(x).sum().backward() 335 self.assertIsNone(x.grad) 336 337 self.assertIsNone( 338 torch.autograd.grad(MyFunction.apply(x), x, allow_unused=True)[0] 339 ) 340 341 def test_materialize_grads(self): 342 class MyFunction(Function): 343 @staticmethod 344 def forward(ctx, x): 345 return x 346 347 @staticmethod 348 def backward(ctx, grad): 349 self.assertEqual(grad, torch.zeros(1)) 350 return grad 351 352 x = torch.ones(1, requires_grad=True) 353 torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward() 354 355 def test_dont_materialize_grads(self): 356 class MyFunction(Function): 357 @staticmethod 358 def forward(ctx, x): 359 ctx.set_materialize_grads(False) 360 return x 361 362 @staticmethod 363 def backward(ctx, grad): 364 self.assertIsNone(grad) 365 return grad 366 367 x = torch.ones(1, requires_grad=True) 368 torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward() 369 370 @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") 371 def test_set_materialize_non_diff_grads(self): 372 class Func(torch.autograd.Function): 373 @staticmethod 374 def forward(ctx, x): 375 out0 = x.clone() 376 out1 = x.clone() 377 ctx.mark_non_differentiable(out1) 378 ctx._materialize_non_diff_grads = False 379 return out0, out1 380 381 @staticmethod 382 def backward(ctx, g0, g1): 383 self.assertIsNone(g1) 384 return g0 385 386 a = torch.tensor(1.0, requires_grad=True) 387 out = Func.apply(a)[0] 388 out.backward() 389 390 def test_legacy_function_deprecation_exception(self): 391 # Trigger exception 392 class MyFunction(Function): 393 def forward(self, x): 394 return x 395 396 def backward(self, grad_output): 397 return grad_output 398 399 # Check exception occurs 400 with self.assertRaisesRegex( 401 RuntimeError, 402 "Legacy autograd function with non-static forward method is deprecated", 403 ): 404 MyFunction()(torch.randn(3, 4)) 405 406 class 
SimulateBackwardError(Function): 407 @staticmethod 408 def forward(ctx, input): 409 return input.clone() 410 411 @staticmethod 412 @once_differentiable 413 def backward(ctx, input): 414 raise Exception("Simulate error on backward pass") # noqa: TRY002 415 416 def test_custom_function_exception(self): 417 t1 = torch.rand((3, 3), requires_grad=True) 418 t2 = torch.rand((3, 3), requires_grad=True) 419 420 tmp = (t1 + t2) * (t1 + t2) 421 t3 = TestAutograd.SimulateBackwardError.apply(tmp) 422 with self.assertRaisesRegex(Exception, "Simulate error on backward pass"): 423 t3.sum().backward() 424 425 def test_custom_function_non_tensor_inputs_outputs(self): 426 class MyFunction(Function): 427 @staticmethod 428 def forward(ctx, t1, t2, scale, t3): 429 t4 = t1 + t2 * t3 430 t5 = t1 * t2 + t3 431 t4 *= scale 432 t5 *= scale 433 434 # Save scale 435 ctx.scale = scale 436 ctx.save_for_backward(t1, t2, t3) 437 return scale, t4, None, True, t5, "bar", t1 438 439 @staticmethod 440 @once_differentiable 441 def backward(ctx, *grads): 442 # Verify grads 443 self.assertEqual(7, len(grads)) 444 self.assertIsNone(grads[0]) 445 self.assertIsNone(grads[2]) 446 self.assertIsNone(grads[3]) 447 self.assertIsNone(grads[5]) 448 449 scale = ctx.scale 450 var1, var2, var3 = ctx.saved_tensors 451 return ( 452 grads[1] * scale + grads[4] * var2 * scale + grads[6], 453 grads[1] * var3 * scale + grads[4] * var1 * scale, 454 None, 455 grads[1] * var2 * scale + grads[4] * scale, 456 ) 457 458 t1 = torch.rand(10, dtype=torch.double, requires_grad=True) 459 t2 = torch.rand(10, dtype=torch.double, requires_grad=True) 460 t3 = torch.rand(10, dtype=torch.double) 461 scale = random.randint(0, 10) 462 res = MyFunction.apply(t1, t2, scale, t3) 463 self.assertEqual(scale, res[0]) 464 self.assertEqual((t1 + t2 * t3) * scale, res[1]) 465 self.assertEqual(None, res[2]) 466 self.assertEqual(True, res[3]) 467 self.assertEqual((t1 * t2 + t3) * scale, res[4]) 468 self.assertEqual("bar", res[5]) 469 self.assertEqual(t1, res[6]) 470 471 # Validate running backward. 
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertIsNotNone(t1.grad)
        self.assertIsNotNone(t2.grad)
        self.assertIsNone(t3.grad)

        # Test gradcheck
        def foo(t1, t2, t3):
            res = MyFunction.apply(t1, t2, scale, t3)
            return res[1], res[4], res[6]

        gradcheck(foo, (t1, t2, t3))

    def test_custom_function_no_tensors(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, t1, t2, scale, t3):
                t4 = t1 + t2 * t3
                t5 = t1 * t2 + t3
                t4 *= scale
                t5 *= scale
                return scale, t4, None, True, t5, "bar", t1

            @staticmethod
            @once_differentiable
            def backward(ctx, *args):
                return (args[0], args[1], None, args[2])

        t1 = random.random()
        t2 = random.random()
        t3 = random.random()
        scale = random.randint(0, 10)
        res = MyFunction.apply(t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

    def test_invalid_gradients(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, grad_output):
                return torch.randn(10, dtype=torch.float)

        with self.assertRaisesRegex(RuntimeError, "expected shape"):
            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
            MyFunction.apply(input).sum().backward()

    def test_unrelated_inputs(self):
        # test to ensure grad(grad)check runs successfully even if there is an
        # unrelated (but differentiable) input

        def my_function(x, y):
            return x * x

        x = torch.rand(10, dtype=torch.double, requires_grad=True)
        y = torch.rand(10, dtype=torch.double, requires_grad=True)

        gradcheck(my_function, (x, y))
        gradgradcheck(my_function, (x, y))

    def test_not_implemented_grad(self):
        a = torch.rand(2, requires_grad=True)
        # if grad for nextafter ends up being implemented, this should be changed
        y = torch.nextafter(a, a).sum()
        with self.assertRaisesRegex(
            NotImplementedError, "the derivative for .* is not implemented"
        ):
            y.backward()

    def test_not_implemented_fwad(self):
        x = torch.randn(3)
        v = torch.rand(3)

        with fwAD.dual_level():
            dual_x = fwAD.make_dual(x, v)

            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"

            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                # if forward AD ends up being implemented for torch.igamma, choose a different op
                torch.igamma(dual_x, dual_x)

    def test_saved_tensor_hooks_extra_exit_during_bw_no_crash(self):
        # This usage of saved tensor is not supported, but should not crash
        def unpack(x):
            ctx_1.__exit__()
            return x

        ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
        ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x)

        for i in range(10):
            with ctx_2:
                ctx_1.__enter__()
                x = torch.randn(3, 3, requires_grad=True)
                x.sin().sum().backward()

        # Clean up
        for i in range(10):
            ctx_1.__exit__()

        # Validate there are no more hooks on the stack
        a = torch.tensor(1.0, requires_grad=True)
        y = a.exp()
y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x) 585 586 def test_saved_tensor_hooks_extra_enter_during_bw_no_leak(self): 587 # This usage of saved tensor is not supported, but should not leak 588 def scope(): 589 def unpack(x): 590 weak_ctx_1().__enter__() 591 return x 592 593 ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack) 594 weak_ctx_1 = weakref.ref(ctx_1) 595 596 x = torch.randn(3, 3, requires_grad=True) 597 with ctx_1: 598 x.sin().sum().backward() 599 return weakref.ref(unpack) 600 601 with disable_gc(): 602 unpack_hook_ref = scope() 603 self.assertIsNone(unpack_hook_ref()) 604 605 def test_will_engine_execute_node(self): 606 counter = [0] 607 608 class MyFunction(Function): 609 @staticmethod 610 def forward(ctx, x): 611 return x * 2 612 613 @staticmethod 614 def backward(ctx, gO): 615 return gO * 2 616 617 def get_grad_fn(t): 618 if t.requires_grad and t.grad_fn is None: 619 return t.clone().grad_fn.next_functions[0][0] 620 else: 621 return t.grad_fn 622 623 a = torch.randn(2, 3, 4, requires_grad=True) 624 a2 = torch.randn(2, 3, 4, requires_grad=True) 625 b = a * a2 626 b2 = b.cos() 627 c = MyFunction.apply(b) 628 629 should_execute = list(map(get_grad_fn, (a, b, c))) 630 should_not_execute = list(map(get_grad_fn, (a2, b2))) 631 632 def fn(x): 633 counter[0] += 1 634 635 for g in should_execute: 636 self.assertTrue(torch._C._will_engine_execute_node(g)) 637 638 for g in should_not_execute: 639 self.assertFalse(torch._C._will_engine_execute_node(g)) 640 641 b.register_hook(fn) 642 c.register_hook(fn) 643 644 # .backward(inputs=) is OK 645 out = c.sum() 646 torch.autograd.backward(out, inputs=(a, b), retain_graph=True) 647 self.assertEqual(counter[0], 2) 648 649 # .backward() is OK 650 should_execute = list(map(get_grad_fn, (a, a2, b, c))) 651 should_not_execute = list(map(get_grad_fn, (b2,))) 652 torch.autograd.backward(out, retain_graph=True) 653 654 # .grad is NOT OK when leaf is passed (this is the current state, subject to change) 655 with self.assertRaisesRegex( 656 RuntimeError, "are currently running autograd.grad()" 657 ): 658 torch.autograd.grad(out, (a,)) 659 660 # .grad is OK when non-leaf is passed 661 a = torch.randn(1, 2, 3, requires_grad=True) * 2 662 b = a * 2 663 664 def fn(x): 665 # Check a non-leaf 666 counter[0] += 1 667 self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn)) 668 669 b.register_hook(fn) 670 counter[0] = 0 671 torch.autograd.grad(b.sum(), (a,)) 672 self.assertEqual(counter[0], 1) 673 674 # Verify other errors are raised 675 with self.assertRaisesRegex(RuntimeError, "during the backward pass"): 676 torch._C._will_engine_execute_node(out.grad_fn) 677 678 with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"): 679 torch._C._will_engine_execute_node(out) 680 681 def test_custom_function_vmap_defaults(self): 682 class MySquare(Function): 683 @staticmethod 684 def forward(x): 685 return x**2 686 687 @staticmethod 688 def setup_context(ctx, inputs, output): 689 (x,) = inputs 690 ctx.save_for_backward(x) 691 692 @staticmethod 693 def backward(ctx, gO): 694 (x,) = ctx.saved_tensors 695 return gO * 2 * x 696 697 self.assertFalse(MySquare.generate_vmap_rule) 698 self.assertTrue(hasattr(MySquare, "vmap")) 699 700 def test_custom_function_setup_context_simple(self): 701 class MySquare(Function): 702 @staticmethod 703 def forward(x): 704 return x**2 705 706 @staticmethod 707 def setup_context(ctx, inputs, output): 708 (x,) = inputs 709 ctx.save_for_backward(x) 710 711 @staticmethod 712 def backward(ctx, 
gO): 713 (x,) = ctx.saved_tensors 714 return gO * 2 * x 715 716 x = torch.randn([], requires_grad=True) 717 y = MySquare.apply(x) 718 (gx,) = torch.autograd.grad(y, x) 719 self.assertEqual(gx, 2 * x) 720 721 def test_custom_function_setup_context_multi_output(self): 722 # Multiple outputs with some non-Tensor outputs. 723 class MySquare(Function): 724 @staticmethod 725 def forward(x): 726 two_x = x.item() * 2 727 return x**2, two_x 728 729 @staticmethod 730 def setup_context(ctx, inputs, output): 731 (x,) = inputs 732 _, two_x = output 733 ctx.two_x = two_x 734 735 @staticmethod 736 @once_differentiable 737 def backward(ctx, gO, _): 738 return gO * ctx.two_x 739 740 x = torch.randn([], requires_grad=True) 741 y, _ = MySquare.apply(x) 742 (gx,) = torch.autograd.grad(y, x) 743 self.assertEqual(gx, 2 * x) 744 745 def test_custom_function_setup_context_multi_input(self): 746 class MyReshape(Function): 747 @staticmethod 748 def forward(x, shape, scale_forward, scale_backward): 749 return x.reshape(shape) * scale_forward 750 751 @staticmethod 752 def setup_context(ctx, inputs, output): 753 x, shape, scale_forward, scale_backward = inputs 754 ctx.scale_backward = scale_backward 755 ctx.x_shape = x.shape 756 757 @staticmethod 758 def backward(ctx, gO): 759 return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None 760 761 class MyReshapeRef(Function): 762 @staticmethod 763 def forward(ctx, x, shape, scale_forward, scale_backward): 764 ctx.scale_backward = scale_backward 765 ctx.x_shape = x.shape 766 return x.reshape(shape) * scale_forward 767 768 @staticmethod 769 def backward(ctx, gO): 770 return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None 771 772 def test(x, shape, scale_forward, scale_backward): 773 y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum() 774 (gx,) = torch.autograd.grad(y, x) 775 776 y_expected = MyReshapeRef.apply( 777 x, shape, scale_forward, scale_backward 778 ).sum() 779 (gx_expected,) = torch.autograd.grad(y_expected, x) 780 781 self.assertEqual(y_expected, y) 782 self.assertEqual(gx_expected, gx) 783 784 test(torch.randn(24, requires_grad=True), (3, 8), 7, 11) 785 test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2) 786 787 def test_multiple_insert_removal_caching(self): 788 torch._C._set_cached_tensors_enabled(True) 789 try: 790 x = torch.rand([4]) 791 792 torch._C._add_cached_tensor(x) 793 self.assertTrue(torch._C._is_cached_tensor(x)) 794 795 torch._C._add_cached_tensor(x) 796 torch._C._remove_cached_tensor(x) 797 798 self.assertFalse(torch._C._is_cached_tensor(x)) 799 finally: 800 torch._C._set_cached_tensors_enabled(False) 801 802 def test_accumulate_grad(self): 803 grad_output = torch.ones(5, 5) 804 805 def compute_grad(create_graph): 806 x = torch.randn(5, 5, requires_grad=True) 807 y = x + 2 808 y.backward(grad_output, retain_graph=True) 809 x_grad = x.grad 810 x_grad_clone = x.grad.clone() 811 y.backward(grad_output, create_graph=create_graph) 812 return x_grad, x_grad_clone 813 814 # Accumulate in-place when create_graph is False 815 x_grad, x_grad_clone = compute_grad(create_graph=False) 816 self.assertEqual(x_grad, x_grad_clone * 2) 817 818 # Accumulate out-of-place when create_graph is False 819 x_grad, x_grad_clone = compute_grad(create_graph=True) 820 self.assertEqual(x_grad, x_grad_clone) 821 822 def test_accumulate_grad_tensor_reference(self): 823 def _test_grad_tensor( 824 params_grad_tensor, 825 backward_grad_tensor, 826 should_preserve_reference, 827 create_graph, 828 ): 829 params = torch.tensor([1.5, 
1.5]).requires_grad_() 830 params.grad = params_grad_tensor 831 grad_saved = params.grad 832 params.backward(backward_grad_tensor, create_graph=create_graph) 833 self.assertEqual( 834 id(grad_saved) == id(params.grad), should_preserve_reference 835 ) 836 837 for create_graph in (False, True): 838 # Accumulate dense gradient to sparse gradient will change the `params.grad` reference 839 _test_grad_tensor( 840 torch.sparse_coo_tensor( 841 torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0]) 842 ), 843 torch.tensor([1.5, 1.5]), 844 False, # never accumulates in-place 845 create_graph, 846 ) 847 848 # Accumulate dense gradient to dense gradient will preserve the `params.grad` reference, 849 # but only if create_graph=False. 850 _test_grad_tensor( 851 torch.tensor([1.5, 1.5]), 852 torch.tensor([1.5, 1.5]), 853 not create_graph, 854 create_graph, 855 ) 856 857 # Accumulate sparse gradient to sparse gradient will preserve the `params.grad` reference, 858 # but only if create_graph=False. 859 _test_grad_tensor( 860 torch.sparse_coo_tensor( 861 torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0]) 862 ), 863 torch.sparse_coo_tensor( 864 torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0]) 865 ), 866 not create_graph, 867 create_graph, 868 ) 869 870 def test_accumulate_grad_with_zero_numel_grad(self): 871 a = torch.rand(4, 0, requires_grad=True) 872 b = torch.rand(4, 1, requires_grad=True) 873 c = a + b 874 assert c.shape == (4, 0) 875 c.sum().backward() 876 877 self.assertEqual(b.grad, torch.zeros(4, 1)) 878 self.assertEqual(a.grad, torch.zeros(4, 0)) 879 880 def test_hessian_vector(self): 881 x = torch.randn(2, 2, requires_grad=True) 882 y = torch.randn(2, 2, requires_grad=True) 883 884 z = x**2 + y * x + y**2 885 z.backward(torch.ones(2, 2), create_graph=True) 886 887 with torch.no_grad(): 888 x_grad = 2 * x + y 889 y_grad = x + 2 * y 890 self.assertEqual(x.grad, x_grad) 891 self.assertEqual(y.grad, y_grad) 892 893 grad_sum = 2 * x.grad + y.grad 894 grad_sum.backward(torch.ones(2, 2)) 895 x_hv = torch.ones(2, 2) * 5 896 y_hv = torch.ones(2, 2) * 4 897 self.assertEqual(x.grad, x_grad + x_hv) 898 self.assertEqual(y.grad, y_grad + y_hv) 899 900 def test_grad(self): 901 x = torch.randn(2, 2, requires_grad=True) 902 y = torch.randn(2, 2, requires_grad=True) 903 z = x**2 + y * x + y**2 904 z.backward(torch.ones(2, 2), create_graph=True) 905 906 x_grad = 2 * x + y 907 y_grad = x + 2 * y 908 self.assertEqual(x.grad, x_grad) 909 self.assertEqual(y.grad, y_grad) 910 911 grad_sum = 2 * x.grad + y.grad 912 x_hv = torch.autograd.grad( 913 outputs=[grad_sum], 914 grad_outputs=[torch.ones(2, 2)], 915 inputs=[x], 916 create_graph=True, 917 ) 918 expected_x_hv = torch.ones(2, 2) * 5 919 expected_y_hv = torch.ones(2, 2) * 4 920 921 self.assertEqual(x_hv[0], expected_x_hv) 922 self.assertEqual(x.grad, x_grad) 923 self.assertEqual(y.grad, y_grad) 924 925 # Test that grad_outputs and outputs have the same shape 926 grad_out = torch.ones(2) 927 try: 928 torch.autograd.grad( 929 outputs=[grad_sum], 930 grad_outputs=[grad_out], 931 inputs=[x], 932 create_graph=True, 933 ) 934 self.assertFail() 935 except RuntimeError as error: 936 self.assertEqual( 937 str(error), 938 "Mismatch in shape: grad_output[0] has a shape of " 939 + str(grad_out.shape) 940 + " and output[0] has a shape of " 941 + str(grad_sum.shape) 942 + ".", 943 ) 944 945 def test_grad_to_node(self): 946 def check_matches(out, inp): 947 ref = torch.autograd.grad(out.sum(), inp) 948 949 edge = torch.autograd.graph.get_gradient_edge(inp) 950 new = 
torch.autograd.grad(out.sum(), edge) 951 self.assertEqual(ref, new) 952 953 # We need to ensure that our main types of Node work (regular cpp Nodes, 954 # AccumulateGrad Nodes and custom Function) 955 x = torch.rand(2, requires_grad=True) 956 out = x.clone() 957 check_matches(out, x) 958 959 x = x.clone() 960 out = x.clone() 961 check_matches(out, x) 962 963 x = torch.autograd._functions.Resize.apply(x, (2,)) 964 out = x.clone() 965 check_matches(out, x) 966 967 x = torch.var_mean(x)[1] 968 out = x.clone() 969 check_matches(out, x) 970 971 def test_grad_to_node_set(self): 972 x = torch.rand(2, requires_grad=True) 973 x_edge = torch.autograd.graph.get_gradient_edge(x) 974 out = x.clone() 975 976 with torch.no_grad(): 977 x.set_(torch.rand_like(x)) 978 979 with self.assertRaisesRegex(RuntimeError, "to not have been used in the graph"): 980 torch.autograd.grad(out.sum(), x) 981 982 # Works 983 torch.autograd.grad(out.sum(), x_edge) 984 985 def test_grad_to_node_inplace(self): 986 x = torch.rand(2, requires_grad=True).clone() 987 x_edge = torch.autograd.graph.get_gradient_edge(x) 988 x *= 2 989 990 g_old, g_new = torch.autograd.grad(x.sum(), (x_edge, x)) 991 self.assertEqual(g_old, 2 * torch.ones_like(x)) 992 self.assertEqual(g_new, torch.ones_like(x)) 993 994 def test_grad_to_node_multi(self): 995 x = torch.rand(2, requires_grad=True).clone() 996 y = torch.rand(2, requires_grad=True).clone() 997 998 out = x + y 999 1000 ref = torch.autograd.grad(out.sum(), (x, y)) 1001 1002 inp_edges = ( 1003 GradientEdge(x.grad_fn, x.output_nr), 1004 GradientEdge(y.grad_fn, y.output_nr), 1005 ) 1006 new = torch.autograd.grad(out.sum(), inp_edges) 1007 1008 self.assertEqual(ref, new) 1009 1010 def test_grad_to_node_materialize(self): 1011 x = torch.rand(2, requires_grad=True).clone() 1012 edge_x = GradientEdge(x.grad_fn, x.output_nr) 1013 y = torch.rand(2, requires_grad=True).clone() 1014 edge_y = GradientEdge(y.grad_fn, y.output_nr) 1015 1016 out = x.clone() 1017 1018 # Works 1019 torch.autograd.grad( 1020 out.sum(), (edge_x, y), allow_unused=True, materialize_grads=True 1021 ) 1022 torch.autograd.grad( 1023 out.sum(), (x, y), allow_unused=True, materialize_grads=True 1024 ) 1025 torch.autograd.grad(out.sum(), (x, edge_y), allow_unused=True) 1026 1027 with self.assertRaisesRegex( 1028 RuntimeError, 1029 "materialize_grads cannot be used when the given input is a GradientEdge", 1030 ): 1031 torch.autograd.grad( 1032 out.sum(), (x, edge_y), allow_unused=True, materialize_grads=True 1033 ) 1034 1035 def test_backward_to_node(self): 1036 x = torch.rand(2, requires_grad=True).clone() 1037 edge_x = GradientEdge(x.grad_fn, x.output_nr) 1038 y = torch.rand(2, requires_grad=True).clone() 1039 edge_y = GradientEdge(y.grad_fn, y.output_nr) 1040 1041 out = x.clone() 1042 1043 # All should work in this case 1044 torch.autograd.backward(out.sum(), inputs=(edge_x, y)) 1045 torch.autograd.backward(out.sum(), inputs=(x, y)) 1046 torch.autograd.backward(out.sum(), inputs=(x, edge_y)) 1047 torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y)) 1048 1049 def test_grad_fn_input_metadata(self): 1050 x = torch.rand(2, requires_grad=True, dtype=torch.float32) 1051 y = torch.rand(2, requires_grad=True, dtype=torch.float32) 1052 z = x * y 1053 z_metadata = z.grad_fn._input_metadata[0] 1054 self.assertEqual(z_metadata.shape, (2,)) 1055 self.assertEqual(z_metadata.dtype, torch.float32) 1056 1057 # Multiple outputs 1058 b = torch.rand(3, 3, requires_grad=True) 1059 var, _ = torch.var_mean(b, dim=0) 1060 1061 metadata_0 = 
var.grad_fn._input_metadata[0] 1062 metadata_1 = var.grad_fn._input_metadata[1] 1063 self.assertEqual(metadata_0.shape, (3,)) 1064 self.assertEqual(metadata_1.shape, (3,)) 1065 1066 # Preserves symints 1067 nt = torch.nested.nested_tensor( 1068 [torch.randn(3, 2), torch.randn(2, 2)], 1069 layout=torch.jagged, 1070 requires_grad=True, 1071 ) 1072 nt_metadata = nt.clone().grad_fn._input_metadata[0] 1073 1074 self.assertIsInstance(nt_metadata.shape[1], torch.SymInt) 1075 self.assertEqual(nt_metadata.shape, nt.shape) 1076 self.assertTrue(nt_metadata.is_nested_tensor) 1077 self.assertFalse(nt_metadata.is_cpp_nested_tensor) 1078 self.assertEqual(nt_metadata.dtype, nt.dtype) 1079 1080 class Test(torch.autograd.Function): 1081 @staticmethod 1082 def forward(ctx, x): 1083 return x 1084 1085 @staticmethod 1086 def backward(ctx, grad_output): 1087 return grad_output 1088 1089 x = torch.randn(3, 3, requires_grad=True) 1090 x = Test.apply(x) 1091 metadata = x.grad_fn._input_metadata[0] 1092 self.assertEqual(metadata.shape, (3, 3)) 1093 1094 def test_gradient_edge_output(self): 1095 x = torch.tensor([1.0, 2.0], requires_grad=True) 1096 1097 def fn(x, reduce=True): 1098 tmp = x.sin().cos() 1099 if reduce: 1100 tmp = tmp.sum() 1101 out = tmp.exp().clone().sin().sum() 1102 tmp_edge = torch.autograd.graph.get_gradient_edge(tmp) 1103 return out, tmp_edge 1104 1105 # Compute fn backward in two steps 1106 out, tmp_edge = fn(x) 1107 (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) 1108 1109 (x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,)) 1110 1111 # Compare with as if we did it in one go. 1112 out, _ = fn(x) 1113 (x_grad_ref,) = torch.autograd.grad(out, (x,)) 1114 self.assertEqual(x_grad, x_grad_ref) 1115 1116 # Incorrect case: grad_outputs not passed/implicitly None and output is 1117 # not a scalar 1118 out, tmp_edge = fn(x, reduce=False) 1119 with self.assertRaisesRegex( 1120 RuntimeError, "grad can be implicitly created only for scalar output" 1121 ): 1122 torch.autograd.grad(tmp_edge, (x,)) 1123 1124 # grad_outputs is None, and output is a scalar is fine 1125 out, tmp_edge = fn(x, reduce=True) 1126 torch.autograd.grad(tmp_edge, (x,)) 1127 1128 # Incorrect case: grad_outputs wrong size 1129 out, tmp_edge = fn(x) 1130 (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) 1131 with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"): 1132 torch.autograd.grad( 1133 tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0]) 1134 ) 1135 1136 # Incorrect case: wrong dtype 1137 out, tmp_edge = fn(x) 1138 (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) 1139 with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"): 1140 torch.autograd.grad( 1141 tmp_edge, 1142 (x,), 1143 grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64), 1144 ) 1145 1146 def test_grad_nonleaf(self): 1147 x_init = torch.randn(2, 2, requires_grad=True) 1148 x = x_init 1149 y = torch.randn(2, 2, requires_grad=True) 1150 grad_output = torch.ones(2, 2) 1151 1152 def fn(x): 1153 return x**2 + y * x + y**2 1154 1155 for _ in range(5): 1156 (grad_x,) = torch.autograd.grad( 1157 fn(x), x, grad_outputs=grad_output, create_graph=True 1158 ) 1159 1160 grad_x_expected = 2 * x + y 1161 self.assertIsNone(y.grad) 1162 self.assertIsNone(x.grad) 1163 self.assertEqual(grad_x, grad_x_expected) 1164 1165 x = x + 0.05 * grad_x 1166 1167 val_init = fn(x_init).sum() 1168 val_final = fn(x).sum() 1169 self.assertGreater(val_final, val_init) 1170 1171 x.backward(grad_output) 1172 
        self.assertIsNotNone(y.grad)
        self.assertIsNotNone(x_init.grad)

    def test_grad_nonleaf_many_outputs(self):
        # This checks an edge case for function callbacks
        # We want to capture two grads of a function, but can only
        # register a single callback.
        x = torch.randn(4, 2, requires_grad=True)
        a, b = x.chunk(2)

        def hook(*grads):
            hook_called[0] = True

        hook_called = [False]
        x.register_hook(hook)

        go = torch.randn(2, 2)
        grad_a, grad_b = torch.autograd.grad(
            (a + 2 * b), [a, b], grad_outputs=go, create_graph=True
        )

        self.assertEqual(grad_a, go)
        self.assertEqual(grad_b, go * 2)
        self.assertFalse(hook_called[0])
        self.assertIsNone(x.grad)

    def test_grad_nonleaf_register_hook(self):
        # This checks an edge case for register_hook.
        # We want to capture grad of a nonleaf tensor,
        # but avoid segfault during backward of other nonleaf tensors
        x = torch.randn(5, requires_grad=True)
        x_list = x.unbind()

        x0 = x_list[0]
        hook_results = [None]

        def hook(grad):
            hook_results[0] = grad

        x0.register_hook(hook)

        x_list[0].backward()
        self.assertEqual(hook_results[0], torch.tensor(1.0))
        expected_grad = torch.tensor([1.0, 0, 0, 0, 0])
        self.assertEqual(x.grad, expected_grad)
        self.assertIsNone(x_list[0].grad)

        for i in range(1, 5, 1):
            x_list[i].backward()
            self.assertEqual(hook_results[0], None)
            expected_grad[i] = 1.0
            self.assertEqual(x.grad, expected_grad)
            self.assertIsNone(x_list[i].grad)

    def test_grad_materialize_grads(self):
        x = torch.tensor(0.5, requires_grad=True)
        a = torch.tensor(1.0, requires_grad=True)
        y = x * a
        dydx = torch.autograd.grad(y, x, create_graph=True)
        d2ydx2_none = torch.autograd.grad(dydx, x, create_graph=True, allow_unused=True)
        d2ydx2 = torch.autograd.grad(
            dydx, x, create_graph=True, allow_unused=True, materialize_grads=True
        )
        # `allow_unused` set to True implicitly
        d3ydx3 = torch.autograd.grad(d2ydx2, x, materialize_grads=True)
        self.assertIsNone(d2ydx2_none[0])
        self.assertEqual(d2ydx2[0].item(), 0)
        self.assertEqual(d3ydx3[0].item(), 0)
        with self.assertRaisesRegex(
            ValueError, "Expected allow_unused to be True or not passed when"
        ):
            torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)

    def test_post_accumulate_grad_hook_on_non_leaf(self):
        def hook(tensor):
            tensor.sub_(1.0)

        leaf = torch.rand(3, requires_grad=True)
        non_leaf = 2.0 * leaf

        with self.assertRaisesRegex(
            RuntimeError,
            "post accumulate grad hooks cannot be registered on non-leaf tensors",
        ):
            non_leaf.register_post_accumulate_grad_hook(hook)

    def test_post_accumulate_grad_hook_multiple_hooks(self):
        def hook1(tensor):
            tensor.sub_(tensor.grad)

        def hook2(tensor):
            tensor.mul_(4.0)

        tensor = torch.rand(3, requires_grad=True)
        tensor_ref = tensor.clone().detach()
        tensor.register_post_accumulate_grad_hook(hook1)
        tensor.register_post_accumulate_grad_hook(hook2)
        sum = tensor.sum()
        sum.backward()
        # both hooks should be called, in order
        self.assertEqual(4.0 * (tensor_ref - 1.0), tensor)

    def test_post_accumulate_grad_hook_multiple_tensors(self):
        def hook(tensor):
            tensor.sub_(tensor.grad)

        tensor1 = torch.rand(3, requires_grad=True)
        tensor1_ref = tensor1.clone().detach()
        tensor2 = torch.rand(5, requires_grad=True)
        tensor2_ref = tensor2.clone().detach()
        tensor1.register_post_accumulate_grad_hook(hook)
        tensor2.register_post_accumulate_grad_hook(hook)
        tensor1.sum().backward()
        tensor2.sum().backward()
        # both tensors should have been modified
        self.assertEqual(tensor1_ref - 1.0, tensor1)
        self.assertEqual(tensor2_ref - 1.0, tensor2)

    def test_post_accumulate_grad_hook_returns_not_None(self):
        def bad_hook(tensor):
            return tensor.grad

        tensor = torch.rand(2, 3, requires_grad=True)
        tensor.register_post_accumulate_grad_hook(bad_hook)
        # should error!
        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
            tensor.sum().backward()

    def test_post_accumulate_grad_hook_e2e(self):
        def setup_optim_in_bwd(model):
            optims = {}
            handles = []

            def optim_step_hook(param):
                optims[param].step()
                optims[param].zero_grad()

            for p in model.parameters():
                optims[p] = torch.optim.Adam([p])
                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))

            return handles

        model = torch.nn.Linear(3, 2)
        input = torch.rand(2, 3)
        handles = setup_optim_in_bwd(model)

        # make a copy for reference
        model_copy = deepcopy(model)
        optim_copy = torch.optim.Adam(model_copy.parameters())

        iters = 5

        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        params_copy = []  # freeze a copy of the params to compare later
        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
            self.assertEqual(p_reference, p)
            params_copy.append(p_reference.clone().detach())

        # After removing the handle, the model should no longer update.
1340 for h in handles: 1341 h.remove() 1342 1343 for _ in range(iters): 1344 loss = model(input).sum() 1345 loss.backward() 1346 1347 loss_copy = model_copy(input).sum() 1348 loss_copy.backward() 1349 optim_copy.step() 1350 optim_copy.zero_grad() 1351 1352 for p_static, p_reference, p in zip( 1353 params_copy, model_copy.parameters(), model.parameters() 1354 ): 1355 self.assertEqual(p_static, p) 1356 self.assertNotEqual(p_reference, p) 1357 1358 def test_post_accumulate_grad_hook_gets_cleaned_up(self): 1359 def fun_stuff_with_hook(): 1360 thing_to_put_in_hook = torch.rand(3) 1361 1362 def hook(tensor): 1363 tensor.sub_(tensor.grad) 1364 tensor.add_(thing_to_put_in_hook) 1365 1366 tensor = torch.rand(3, requires_grad=True) 1367 tensor.register_post_accumulate_grad_hook(hook) 1368 tensor.sum().backward() 1369 ref = weakref.ref(thing_to_put_in_hook) 1370 gc.collect() 1371 return tensor, ref 1372 1373 with disable_gc(): 1374 tensor, ref = fun_stuff_with_hook() 1375 self.assertIsNotNone( 1376 ref() 1377 ) # thing_to_put_in_hook should be kept alive by tensor 1378 1379 del tensor 1380 gc.collect() 1381 self.assertIsNone(ref()) # thing_to_put_in_hook should be cleaned 1382 1383 def test_post_accumulate_grad_hook_ordering(self): 1384 tensor = torch.rand(3, requires_grad=True) 1385 1386 def pre_hook(grad): 1387 return grad.sub(2.0) 1388 1389 def acc_grad_node_pre_hook(grad_out): 1390 return (grad_out[0].div(5.0),) 1391 1392 def post_acc_grad_hook(tensor): 1393 tensor.grad.add_(0.5) 1394 1395 def acc_grad_node_post_hook(grad_in, grad_out): 1396 tensor.grad = grad_out[0].mul(10) 1397 1398 acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0] 1399 tensor.register_hook(pre_hook) 1400 acc_grad.register_prehook(acc_grad_node_pre_hook) 1401 tensor.register_post_accumulate_grad_hook(post_acc_grad_hook) 1402 acc_grad.register_hook(acc_grad_node_post_hook) 1403 tensor.sum().backward() 1404 1405 # the hooks should run in the order of: 1406 # 1. tensor prehook 1407 # 2. acc_grad prehook 1408 # 3. tensor post acc_grad hook 1409 # 4. 
acc_grad posthook 1410 # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3 1411 self.assertEqual(torch.tensor([3.0, 3.0, 3.0]), tensor.grad) 1412 1413 def test_hook_with_no_name(self): 1414 # Create a hook that do not have a __name__ attribute 1415 class MyHookClass: 1416 def __call__(self, grad): 1417 return grad.clone() 1418 1419 x = torch.randn(5, requires_grad=True).clone() 1420 x.register_hook(MyHookClass()) 1421 x.sum().backward() 1422 # Should run fine 1423 1424 def test_prehook_ordering(self): 1425 # Hooks registered to tensor are ordered before those 1426 # that are registered to grad_fn 1427 log = [] 1428 1429 def hook1(g): 1430 log.append(1) 1431 return g * 3 1432 1433 def hook2(gs): 1434 log.append(2) 1435 return tuple(g * 2 for g in gs) 1436 1437 a = torch.tensor(1.0, requires_grad=True) 1438 b = a.clone() 1439 1440 b.grad_fn.register_prehook(hook2) 1441 b.register_hook(hook1) 1442 b.grad_fn.register_prehook(hook2) 1443 1444 acc = b.grad_fn.next_functions[0][0] 1445 a.register_hook(hook1) 1446 acc.register_prehook(hook2) 1447 a.register_hook(hook1) 1448 1449 b.sum().backward(retain_graph=True) 1450 self.assertEqual(log, [1, 2, 2, 1, 1, 2]) 1451 1452 # grad also runs hooks on accumulate grad nodes, even though 1453 # the accumulate grad nodes are not actually executed 1454 log = [] 1455 torch.autograd.grad(b.sum(), inputs=(a,), retain_graph=True) 1456 self.assertEqual(log, [1, 2, 2, 1, 1]) 1457 1458 log = [] 1459 b.sum().backward(inputs=(b,)) 1460 self.assertEqual(log, [1, 2, 2]) 1461 # retains_grad hooks would not observe modifications by all pre hooks 1462 # because they are executed after 1463 self.assertEqual(b.grad.item(), 3) 1464 1465 def test_retains_grad_can_always_observe_tensor_prehook(self): 1466 def tensor_prehook(g): 1467 return g * 2 1468 1469 a = torch.tensor(1.0, requires_grad=True) 1470 b = a.clone() 1471 b.register_hook(tensor_prehook) 1472 b.retain_grad() 1473 b.register_hook(tensor_prehook) 1474 1475 b.clone().backward() 1476 self.assertEqual(b.grad.item(), 4) 1477 1478 a = torch.tensor(1.0, requires_grad=True) 1479 b = a.clone() 1480 b.retain_grad() 1481 b.register_hook(tensor_prehook) 1482 1483 b.clone().backward() 1484 self.assertEqual(b.grad.item(), 2) 1485 1486 def test_accumulate_grad_posthooks_can_observe_tensor_prehook(self): 1487 # Post hooks on accumulate should be able to observe changes to 1488 # grad made by tensor prehooks 1489 a = torch.tensor(1.0, requires_grad=True) 1490 1491 def tensor_prehook(g): 1492 return g * 2 1493 1494 def posthook(gO, gI): 1495 self.assertTrue(torch.allclose(gI[0], a * 2)) 1496 self.assertEqual(len(gO), 0) 1497 1498 def prehook(gI): 1499 self.assertTrue(torch.allclose(gI[0], a * 2)) 1500 self.assertEqual(len(gI), 1) 1501 1502 b = a.clone() 1503 acc = b.grad_fn.next_functions[0][0] 1504 acc.register_hook(posthook) 1505 acc.register_prehook(prehook) 1506 a.register_hook(tensor_prehook) 1507 1508 b.backward() 1509 1510 def test_accumulate_grad_posthooks_should_not_execute(self): 1511 def tensor_prehook(g): 1512 raise RuntimeError 1513 1514 def posthook(gO, gI): 1515 raise RuntimeError 1516 1517 a = torch.tensor(1.0, requires_grad=True) 1518 a.register_hook(tensor_prehook) 1519 b = torch.tensor(1.0, requires_grad=True) 1520 c = a.clone() 1521 acc = c.grad_fn.next_functions[0][0] 1522 acc.register_hook(posthook) 1523 1524 out = a + b + c 1525 out.sum().backward(inputs=[b]) 1526 1527 def test_hook_edge_case_when_called_with_grad(self): 1528 # grad executes the tensor hooks of the next node but not 1529 # grad_fn pre hooks or 
the post hooks 1530 a = torch.tensor(1.0, requires_grad=True) 1531 b = a * 2 1532 c = b * 2 1533 1534 tensor_hook_count = [0] 1535 prehook_count = [0] 1536 posthook_count = [0] 1537 1538 def reset_counts(): 1539 nonlocal tensor_hook_count, prehook_count, posthook_count 1540 tensor_hook_count = [0] 1541 prehook_count = [0] 1542 posthook_count = [0] 1543 1544 def tensor_prehook(g): 1545 tensor_hook_count[0] += 1 1546 1547 def prehook(g): 1548 prehook_count[0] += 1 1549 1550 def posthook(gI, gO): 1551 posthook_count[0] += 1 1552 1553 a.register_hook(tensor_prehook) 1554 b.register_hook(tensor_prehook) 1555 acc = b.grad_fn.next_functions[0][0] 1556 acc.register_hook(posthook) 1557 acc.register_prehook(prehook) 1558 b.grad_fn.register_hook(posthook) 1559 b.grad_fn.register_prehook(prehook) 1560 1561 torch.autograd.grad(c, inputs=(b), retain_graph=True) 1562 self.assertEqual(tensor_hook_count[0], 1) 1563 self.assertEqual(posthook_count[0], 0) 1564 self.assertEqual(prehook_count[0], 0) 1565 reset_counts() 1566 1567 torch.autograd.grad(c, inputs=(a, b), retain_graph=True) 1568 self.assertEqual(tensor_hook_count[0], 2) 1569 self.assertEqual(posthook_count[0], 1) 1570 self.assertEqual(prehook_count[0], 1) 1571 reset_counts() 1572 1573 c.backward(retain_graph=True) 1574 self.assertEqual(tensor_hook_count[0], 2) 1575 self.assertEqual(posthook_count[0], 2) 1576 self.assertEqual(prehook_count[0], 2) 1577 reset_counts() 1578 1579 c.backward(inputs=(a, b), retain_graph=True) 1580 self.assertEqual(tensor_hook_count[0], 2) 1581 self.assertEqual(posthook_count[0], 2) 1582 self.assertEqual(prehook_count[0], 2) 1583 1584 def test_sharded_grad(self): 1585 leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)] 1586 intermediates = [l * i + l * l for i, l in enumerate(leaves)] 1587 loss = sum(v * i for i, v in enumerate(intermediates)).sum() 1588 1589 # define a helper for dividing intermediates into groups 1590 def group(l, group_size): 1591 return (l[i : i + group_size] for i in range(0, len(l), group_size)) 1592 1593 # Compute the d loss / d intermediates in chunks of shard_size 1594 shard_size = 2 1595 d_intermediates = [ 1596 d_i 1597 for intermediates_batch in group(intermediates, shard_size) 1598 for d_i in torch.autograd.grad(loss, intermediates_batch) 1599 ] 1600 # Compute rest of backward pass 1601 torch.autograd.backward(intermediates, d_intermediates) 1602 1603 for i, l in enumerate(leaves): 1604 self.assertEqual(l.grad, i * i * (1 + l)) 1605 1606 def test_backward_badcalls(self): 1607 x = torch.ones(1) 1608 with self.assertRaisesRegex(RuntimeError, "does not require grad"): 1609 x.backward() 1610 1611 def test_grad_badcalls(self): 1612 x = torch.ones(1) 1613 y = x**2 1614 with self.assertRaisesRegex(RuntimeError, "does not require grad"): 1615 torch.autograd.grad(x, y) 1616 with self.assertRaisesRegex(RuntimeError, "does not require grad"): 1617 torch.autograd.grad(y, x) 1618 1619 x = torch.ones(1, requires_grad=True) 1620 y = x**2 1621 torch.autograd.grad(y, x) # this should succeed now 1622 1623 def test_grad_empty_inputs(self): 1624 x = torch.tensor([1.0], requires_grad=True) 1625 with self.assertRaisesRegex(ValueError, "grad requires non-empty inputs."): 1626 torch.autograd.grad(2 * x, [], grad_outputs=torch.tensor([1.0])) 1627 1628 def test_grad_fn_badcalls(self): 1629 error_regex = "expected .* arguments, got .* instead" 1630 x = torch.ones(1, requires_grad=True) 1631 y = x**2 1632 with self.assertRaisesRegex(TypeError, error_regex): 1633 y.grad_fn(x.detach(), x.detach()) # too 
many 1634 with self.assertRaisesRegex(TypeError, error_regex): 1635 y.grad_fn() # too few 1636 1637 y.grad_fn(x.detach()) # this should succeed 1638 1639 def test_grad_unreachable(self): 1640 x = torch.ones(1, requires_grad=True) 1641 y = torch.ones(1, requires_grad=True) 1642 # Make sure x and y have grad accumulators allocated 1643 z = x * 2 1644 w = y * 2 1645 1646 grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True) 1647 self.assertEqual(grad_x, x * 2) 1648 self.assertIsNone(grad_y) 1649 1650 # This is slightly different than the case above, because z doesn't even 1651 # have a grad accumulator allocated. 1652 z = torch.ones(1, requires_grad=True) 1653 grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True) 1654 self.assertEqual(grad_x, x * 2) 1655 self.assertIsNone(grad_z) 1656 1657 # allow_unused=False, but grads contains None inside, should throw 1658 with self.assertRaisesRegex(RuntimeError, "Set allow_unused=True"): 1659 grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=False) 1660 1661 def test_grad_unreachable_discovery(self): 1662 # Test that certain nodes are not erroneously executed when an input 1663 # is unreachable. See #39784 1664 class MyFunc(torch.autograd.Function): 1665 @staticmethod 1666 def forward(ctx, x): 1667 return x 1668 1669 @staticmethod 1670 def backward(ctx, x): 1671 self.fail("This node should not be executed!") 1672 1673 x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2) 1674 y = torch.randn(1, requires_grad=True) 1675 (gY,) = torch.autograd.grad(x, (y,), allow_unused=True) 1676 self.assertIsNone(gY) 1677 1678 x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2) 1679 y = torch.randn(1, requires_grad=True) 1680 z = torch.randn(1, requires_grad=True) 1681 (gY, gZ) = torch.autograd.grad(x + z, (y, z), allow_unused=True) 1682 self.assertIsNone(gY) 1683 self.assertIsNotNone(gZ) 1684 1685 x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2) 1686 y = torch.randn(1, requires_grad=True) 1687 torch.autograd.backward(x, inputs=(y,)) # allow_unused is implicitly True! 1688 self.assertIsNone(y.grad) 1689 1690 def test_grad_batched_grad(self): 1691 x = torch.randn(2, 2, requires_grad=True) 1692 1693 out = x.clone() # Size([2, 2]) 1694 batched_grad = ( 1695 torch.arange(3).expand(2, 2, 3).transpose(0, 2) 1696 ) # Size([3, 2, 2]) 1697 (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True) 1698 self.assertEqual( 1699 grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype) 1700 ) 1701 1702 # Detect shape mismatch 1703 grad_out = torch.ones(2, 2) 1704 with self.assertRaisesRegex( 1705 RuntimeError, "If `is_grads_batched=True`, we interpret the first" 1706 ): 1707 torch.autograd.grad( 1708 outputs=out, 1709 grad_outputs=(grad_out,), 1710 inputs=(x,), 1711 is_grads_batched=True, 1712 ) 1713 1714 # Scalar outputs 1715 out = x.sum() # Size([]) 1716 batched_grad = torch.arange(3) # Size([3]) 1717 (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True) 1718 self.assertEqual( 1719 grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype) 1720 ) 1721 1722 # We consider scalar and sized-1 to be a mismatch. This is consistent with current non-batched behavior. 
1723 grad_out = torch.ones(2).unsqueeze(1) 1724 with self.assertRaisesRegex( 1725 RuntimeError, "If `is_grads_batched=True`, we interpret the first" 1726 ): 1727 torch.autograd.grad( 1728 outputs=out, 1729 grad_outputs=(grad_out,), 1730 inputs=(x,), 1731 is_grads_batched=True, 1732 ) 1733 1734 def test_hooks(self): 1735 x = torch.ones(5, 5, requires_grad=True) 1736 y = torch.ones(5, 5) * 4 1737 y.requires_grad_(True) 1738 1739 counter = [0] 1740 1741 def bw_hook(inc, grad): 1742 self.assertIsInstance(grad, torch.Tensor) 1743 counter[0] += inc 1744 1745 z = x**2 + x * 2 + x * y + y 1746 x.register_hook(lambda *args: bw_hook(0, *args)) 1747 test = z.register_hook(lambda *args: bw_hook(1, *args)) 1748 z.backward(torch.ones(5, 5), retain_graph=True) 1749 self.assertEqual(counter[0], 1) 1750 1751 test2 = z.register_hook(lambda *args: bw_hook(2, *args)) 1752 z.backward(torch.ones(5, 5), retain_graph=True) 1753 self.assertEqual(counter[0], 4) 1754 1755 test2.remove() 1756 z.backward(torch.ones(5, 5), retain_graph=True) 1757 self.assertEqual(counter[0], 5) 1758 1759 def bw_hook_modify(grad): 1760 return grad.mul(2) 1761 1762 test.remove() 1763 z.register_hook(bw_hook_modify) 1764 with torch.no_grad(): 1765 y.grad.zero_() 1766 z.backward(torch.ones(5, 5), retain_graph=True) 1767 self.assertEqual(y.grad, (x + 1) * 2) 1768 1769 y.register_hook(bw_hook_modify) 1770 with torch.no_grad(): 1771 y.grad.zero_() 1772 z.backward(torch.ones(5, 5)) 1773 self.assertEqual(y.grad, (x + 1) * 4) 1774 1775 def _get_mul2(self, use_custom_function): 1776 if use_custom_function: 1777 1778 class Mul2(Function): 1779 @staticmethod 1780 def forward(ctx, x): 1781 return x * 2 1782 1783 @staticmethod 1784 def backward(ctx, gO): 1785 return gO * 2 1786 1787 return Mul2.apply 1788 else: 1789 return lambda x: x * 2 1790 1791 def test_grad_fn_prehooks(self): 1792 for use_custom_function in (True, False): 1793 mul2 = self._get_mul2(use_custom_function) 1794 1795 a = torch.tensor([1.0], requires_grad=True) 1796 b = mul2(a) 1797 1798 post_counter = [0] 1799 pre_counter = [0] 1800 1801 def posthook(grad_input, grad_output): 1802 self.assertEqual(pre_counter[0], 3) 1803 self.assertTrue(torch.allclose(grad_output[0], torch.ones(1) * 8)) 1804 self.assertTrue(torch.allclose(grad_input[0], torch.ones(1) * 16)) 1805 post_counter[0] += 1 1806 return grad_input 1807 1808 def prehook(grad_output): 1809 pre_counter[0] += 1 1810 return (grad_output[0] * 2,) 1811 1812 # register posthook x 2 1813 b.grad_fn.register_hook(posthook) 1814 b.grad_fn.register_hook(posthook) 1815 # register prehook x 3 1816 b.grad_fn.register_prehook(prehook) 1817 b.grad_fn.register_prehook(lambda x: None) 1818 b.grad_fn.register_prehook(prehook) 1819 b.grad_fn.register_prehook(prehook) 1820 b.grad_fn.register_prehook(lambda x: x) 1821 b.grad_fn.register_prehook(lambda x: None) 1822 1823 b.sum().backward() 1824 1825 self.assertEqual(post_counter[0], 2) 1826 self.assertEqual(pre_counter[0], 3) 1827 1828 # Return None 1829 a = torch.rand(3, 3, requires_grad=True) 1830 b = mul2(a) 1831 1832 def prehook(grad_output): 1833 pre_counter[0] += 1 1834 return None 1835 1836 b.grad_fn.register_prehook(prehook) 1837 b.sum().backward() 1838 self.assertEqual(pre_counter[0], 4) 1839 self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2)) 1840 1841 def test_grad_fn_prehooks_multiple_outputs(self): 1842 # Compute gradients without hooks 1843 b = torch.rand(3, 3, requires_grad=True) 1844 var, mean = torch.var_mean(b, dim=0) 1845 (var + mean).sum().backward() 1846 1847 # Compute 
gradients with hooks 1848 a = b.detach().requires_grad_() 1849 counter = [0] 1850 1851 def prehook(grad_output): 1852 gvar, gmean = grad_output 1853 counter[0] += 1 1854 return (gvar * 2, gmean * 2) 1855 1856 var, mean = torch.var_mean(a, dim=0) 1857 mean.grad_fn.register_prehook(prehook) 1858 (var + mean).sum().backward() 1859 1860 self.assertEqual(counter[0], 1) 1861 # Compare 1862 self.assertTrue(torch.allclose(a.grad, b.grad * 2)) 1863 1864 # Test with custom Function 1865 class DoubleMul2(Function): 1866 @staticmethod 1867 def forward(ctx, x, a, y): 1868 ctx.a = a 1869 return a * x * 2, a, a * y * 2 1870 1871 @staticmethod 1872 def backward(ctx, g1, _a, g2): 1873 return ctx.a * g1 * 2, None, ctx.a * g2 * 2 1874 1875 counter = [0] 1876 1877 def prehook(grad_output): 1878 g1, ga, g2 = grad_output 1879 self.assertIsNone(ga) 1880 counter[0] += 1 1881 return (g1 * 2, None, g2 * 2) 1882 1883 a = torch.randn(3, 3, requires_grad=True) 1884 b = torch.randn(3, 3, requires_grad=True) 1885 k = 3 1886 c, _, d = DoubleMul2.apply(a, k, b) 1887 c.grad_fn.register_prehook(prehook) 1888 (c + d).sum().backward() 1889 1890 self.assertEqual(counter[0], 1) 1891 self.assertTrue(torch.allclose(a.grad, torch.ones(1) * 4 * k)) 1892 self.assertTrue(torch.allclose(b.grad, torch.ones(1) * 4 * k)) 1893 1894 def test_grad_fn_prehooks_remove_hooks(self): 1895 for use_custom_function in (True, False): 1896 mul2 = self._get_mul2(use_custom_function) 1897 1898 # Simply remove hooks 1899 1900 a = torch.rand(3, 3, requires_grad=True) 1901 b = mul2(a) 1902 counter = [0] 1903 1904 def prehook(grad_output): 1905 counter[0] += 1 1906 return None 1907 1908 handle = b.grad_fn.register_prehook(prehook) 1909 b.grad_fn.register_prehook(prehook) 1910 handle.remove() 1911 b.sum().backward() 1912 self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2)) 1913 self.assertEqual(counter[0], 1) 1914 1915 # Remove hooks during backward 1916 a = torch.rand(3, 3, requires_grad=True) 1917 b = mul2(a) 1918 counter = [0] 1919 1920 def prehook1(grad_output): 1921 handle2.remove() 1922 # Remove hook that is already removed is OK 1923 handle3.remove() 1924 return None 1925 1926 def prehook2(grad_output): 1927 counter[0] += 1 1928 return None 1929 1930 # Hooks that registered first run first 1931 b.grad_fn.register_prehook(prehook1) 1932 handle2 = b.grad_fn.register_prehook(prehook2) 1933 handle3 = b.grad_fn.register_prehook(prehook2) 1934 handle3.remove() 1935 b.sum().backward() 1936 self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2)) 1937 self.assertEqual(counter[0], 1) 1938 1939 def test_node_post_hook_registered_during_unpack_hook(self): 1940 """ 1941 Test that post hooks registered during one of the node's 1942 unpack hooks are properly restricted and will run properly. 
        """
        test_case = self

        class RegisterPostNodeHook(torch.autograd.graph.saved_tensors_hooks):
            def __init__(self) -> None:
                def pack_tensor(tensor: torch.Tensor) -> torch.Tensor:
                    return tensor

                def unpack_tensor(tensor: torch.Tensor) -> torch.Tensor:
                    node = torch._C._current_autograd_node()

                    def hook(outputs, inputs):
                        # Assert that inputs passed in are None
                        test_case.assertTrue(all(i is None for i in inputs))
                        halved_outputs = tuple(
                            o / 2.0 if o is not None else None for o in outputs
                        )
                        return halved_outputs

                    node.register_hook(hook)
                    return tensor

                super().__init__(pack_tensor, unpack_tensor)

        a = torch.rand(3, 3, requires_grad=True)

        def model():
            var, mean = torch.var_mean(a, dim=0)
            loss = (var + mean).sum()
            loss.backward()

        model()
        ref_grad = a.grad.clone()

        with RegisterPostNodeHook():
            model()

        # Verify that the post hook got called and the grad propagation worked
        self.assertEqual(ref_grad / 2.0 + ref_grad, a.grad)

    def test_hooks_cpp(self):
        # Tests hooks for autograd function implemented in C++
        bn = torch.nn.BatchNorm1d(5, affine=False)
        bn.double()
        bn.eval()

        counter = [0]

        def bw_hook(grad):
            counter[0] += 1
            return grad * 2

        x = torch.ones(5, 5, dtype=torch.double, requires_grad=True)
        z = bn(x)
        z.register_hook(bw_hook)
        z.sum().backward()

        self.assertEqual(counter[0], 1, msg="bw_hook not called")
        self.assertEqual(
            x.grad, torch.ones(5, 5, dtype=torch.double) * 2, atol=1e-5, rtol=0
        )

    def test_hook_none(self):
        # WARNING: this is a test for autograd internals.
        # You should never have to use such things in your code.
2008 class NoneGradientFunction(Function): 2009 @staticmethod 2010 def forward(ctx, x, y): 2011 assert ctx.needs_input_grad[0] 2012 assert not ctx.needs_input_grad[1] 2013 return x, y 2014 2015 @staticmethod 2016 def backward(ctx, grad_x, grad_y): 2017 return grad_x, None 2018 2019 was_called = [False] 2020 2021 def hook(grad): 2022 self.assertIsNotNone(grad) 2023 was_called[0] = True 2024 2025 x = torch.randn(5, 5, requires_grad=True) 2026 y = torch.randn(5, 5) 2027 rx, ry = NoneGradientFunction.apply(x, y) 2028 rx.register_hook(hook) 2029 ry.register_hook(hook) 2030 sum(rx, ry).sum().backward() 2031 self.assertTrue(was_called[0]) 2032 2033 def test_retain_grad(self): 2034 input = torch.rand(1, 3, requires_grad=True) 2035 h1 = input * 3 2036 out = (h1 * h1).sum() 2037 2038 # It should be possible to call retain_grad() multiple times 2039 h1.retain_grad() 2040 h1.retain_grad() 2041 2042 # Gradient should be accumulated 2043 out.backward(retain_graph=True) 2044 self.assertEqual(h1 * 2, h1.grad) 2045 out.backward(retain_graph=True) 2046 self.assertEqual(h1 * 4, h1.grad) 2047 2048 with torch.no_grad(): 2049 input.grad.zero_() 2050 # It should be a no-op for leaves 2051 input.retain_grad() 2052 input.retain_grad() 2053 out.backward() 2054 self.assertEqual(input * 18, input.grad) 2055 2056 # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between 2057 # retains_grad and hooks in cpp 2058 def test_retain_grad_inplace(self): 2059 a = torch.tensor([1.0], requires_grad=True).clone() 2060 a.retain_grad() 2061 a.mul_(2) 2062 a.sum().backward() 2063 self.assertEqual(a.grad, torch.tensor([1.0])) 2064 2065 a = torch.tensor([1.0], requires_grad=True).clone() 2066 a.retain_grad() 2067 # Inplace multiple times is OK 2068 a.mul_(2) 2069 a.mul_(2) 2070 a.sum().backward() 2071 self.assertEqual(a.grad, torch.tensor([1.0])) 2072 2073 # When in-place over view is done, the retains_grad hooks should be 2074 # moved from base's original grad_fn to the copyslices node. 2075 x = torch.tensor([1.0], requires_grad=True).clone() 2076 x.retain_grad() 2077 x_view = x[:] 2078 x_view *= 2 2079 x *= 2 2080 x.sum().backward() 2081 # The grad is 1, not 4, because we are computing grad wrt the latest 2082 # version of x. 2083 self.assertEqual(a.grad, torch.tensor([1.0])) 2084 2085 # If the base did not originally require grad, there should be no hook 2086 # to move. Make sure this case runs without error. 
2087 x = torch.zeros(4) 2088 y = x.view(2, 2) 2089 y.add_(torch.randn(2, 2, requires_grad=True)) 2090 2091 def test_retains_grad_inplace_multiple_outputs(self): 2092 class DoubleMul(Function): 2093 @staticmethod 2094 def forward(ctx, x): 2095 return x * 2, x * 3 2096 2097 @staticmethod 2098 def backward(ctx, g1, g2): 2099 return g1 * 2 + g2 * 3 2100 2101 var_mean = partial(torch.var_mean, dim=0) 2102 2103 for fn in (DoubleMul.apply, var_mean): 2104 b = torch.rand(3, 3, requires_grad=True) 2105 var, mean = fn(b) 2106 var.retain_grad() 2107 mean.retain_grad() 2108 # node has two retains_grad hooks 2109 var.mul_(2) 2110 # the retain_grad hook multi-output node refers should now be a nullptr 2111 (var + mean).sum().backward() 2112 gvar = var.grad 2113 gmean = mean.grad 2114 2115 a = b.detach().requires_grad_(True) 2116 var, mean = fn(a) 2117 var.mul_(2) 2118 out = (var + mean).sum() 2119 gvar_expected, gmean_expected = torch.autograd.grad(out, inputs=(var, mean)) 2120 self.assertTrue(torch.allclose(gvar, gvar_expected)) 2121 self.assertTrue(torch.allclose(gmean, gmean_expected)) 2122 2123 def test_retain_grad_inplace_over_view(self): 2124 base = torch.tensor([1.0], requires_grad=True).clone() 2125 view = base[:] 2126 view2 = base[:] 2127 view.retain_grad() 2128 view2.retain_grad() 2129 view.mul_(2) 2130 (view + view2).sum().backward() 2131 2132 # The old grad_fn, slice, wouldn't be part of the graph during backward 2133 # so if the retains grad were not properly updated to the new grad_fn, 2134 # the grad would still be None 2135 self.assertEqual(view.grad, view2.grad) 2136 self.assertEqual(view.grad, torch.tensor([1.0])) 2137 2138 def test_tensor_hooks_inplace(self): 2139 # Check that the second hook gets registered to the new version of tensor 2140 count1 = [0] 2141 count2 = [0] 2142 2143 def fn1(grad): 2144 count1[0] += 1 2145 # x2 from mul, x2 from fn2 2146 self.assertEqual(grad, torch.tensor([4.0])) 2147 return grad * 2 2148 2149 def fn2(grad): 2150 count2[0] += 1 2151 self.assertEqual(grad, torch.tensor([1.0])) 2152 return grad * 2 2153 2154 a = torch.tensor([1.0], requires_grad=True) 2155 b = a.clone() 2156 b.register_hook(fn1) 2157 b.mul_(2) 2158 b.register_hook(fn2) 2159 b.sum().backward() 2160 self.assertEqual(count1[0], 1) 2161 self.assertEqual(count2[0], 1) 2162 self.assertEqual(a.grad, torch.tensor([8.0])) 2163 2164 count3 = [0] 2165 2166 def fn3(grad): 2167 count3[0] += 1 2168 self.assertEqual(grad, torch.tensor([4.0])) 2169 return grad * 2 2170 2171 a = torch.tensor([1.0], requires_grad=True) 2172 b = a.clone() 2173 b.register_hook(fn3) 2174 # Inplace multiple times is OK 2175 b.mul_(2) 2176 b.mul_(2) 2177 b.sum().backward() 2178 self.assertEqual(count1[0], 1) 2179 self.assertEqual(a.grad, torch.tensor([8.0])) 2180 2181 def test_tensor_hooks_inplace_multiple_outputs(self): 2182 class DoubleMul(Function): 2183 @staticmethod 2184 def forward(ctx, x): 2185 return x * 2, x * 3 2186 2187 @staticmethod 2188 def backward(ctx, g1, g2): 2189 return g1 * 2 + g2 * 3 2190 2191 var_mean = partial(torch.var_mean, dim=0) 2192 2193 for fn in (DoubleMul.apply, var_mean): 2194 counts = [0, 0, 0] 2195 2196 def fn0(grad): 2197 counts[0] += 1 2198 self.assertEqual(grad, torch.ones_like(out1) * 2) 2199 2200 def fn1(grad): 2201 counts[1] += 1 2202 self.assertEqual(grad, torch.ones_like(out1) * 3) 2203 2204 def fn2(grad): 2205 counts[2] += 1 2206 self.assertEqual(grad, torch.ones_like(out1)) 2207 2208 b = torch.rand(3, 3, requires_grad=True) 2209 out1, out2 = fn(b) 2210 out1.register_hook(fn0) 2211 
out2.register_hook(fn1) 2212 # node refers to two hook dicts 2213 # out1 no longer no longer points to its old hook dict 2214 out1.mul_(2) 2215 # fn2 is registered to out1's new hook dict 2216 out1.register_hook(fn2) 2217 (out1 + out2 * 3).sum().backward() 2218 self.assertEqual(counts, [1, 1, 1]) 2219 2220 def test_tensor_hooks_inplace_over_view(self): 2221 # There might be a better UX here, but this is the way it is now 2222 count = [0] 2223 2224 def fn0(grad): 2225 self.fail() 2226 2227 def fn1(grad): 2228 self.fail() 2229 2230 def fn2(grad): 2231 count[0] += 1 2232 self.assertEqual(grad, torch.tensor([1.0])) 2233 2234 base = torch.tensor([1.0], requires_grad=True).clone() 2235 view = base[:] 2236 view2 = base[:] 2237 view.register_hook(fn0) 2238 view2.register_hook(fn1) 2239 view.mul_(2) 2240 # We need to explicitly trigger an update to view to update its grad_fn 2241 view2.grad_fn 2242 view2.register_hook(fn2) 2243 (view + view2).sum().backward() 2244 # The hooks originally registered to view are not fired, one must explicitly 2245 # trigger an update to the view's grad_fn, and then register a new hook 2246 self.assertEqual(count[0], 1) 2247 2248 def test_retain_grad_cycle(self): 2249 x = torch.ones(5, 5, requires_grad=True) 2250 2251 def run_test(): 2252 y = x * 2 2253 y.retain_grad() 2254 2255 return y / 2, torch._C._WeakTensorRef(y) 2256 2257 z, ref = run_test() 2258 self.assertTrue(ref.expired()) 2259 z.sum().backward() 2260 2261 def test_backward(self): 2262 v = torch.randn(5, 5, requires_grad=True) 2263 x = torch.randn(5, 5, requires_grad=True) 2264 y = (torch.rand(5, 5) + 0.1).requires_grad_(True) 2265 z = torch.randn(5, 5, requires_grad=True) 2266 grad_output = torch.randn(5, 5) 2267 2268 v.backward(grad_output) 2269 self.assertEqual(v.grad, grad_output) 2270 2271 a = x + (y * z) + 4 * z**2 * x / y 2272 a.backward(grad_output) 2273 x_grad = 4 * z.pow(2) / y + 1 2274 y_grad = z - 4 * x * z.pow(2) / y.pow(2) 2275 z_grad = 8 * x * z / y + y 2276 self.assertEqual(x.grad, x_grad * grad_output) 2277 self.assertEqual(y.grad, y_grad * grad_output) 2278 self.assertEqual(z.grad, z_grad * grad_output) 2279 2280 def test_to_sparse_backward(self): 2281 to_attr_names = ( 2282 "to_dense", 2283 "to_sparse", 2284 "to_sparse_csr", 2285 "to_sparse_csc", 2286 "to_sparse_bsr", 2287 "to_sparse_bsc", 2288 ) 2289 to_params = ((), (), (), (), (2,), (2,)) 2290 to_attr_names_params = dict(zip(to_attr_names, to_params)) 2291 2292 def check_inversion_possible( 2293 t, layout1, layout1_params, layout2, layout2_params 2294 ): 2295 l = (layout1, layout2) 2296 p = (layout1_params, layout2_params) 2297 for l1, l2, p1, p2 in ((*l, *p), (*l[::-1], *p[::-1])): 2298 try: 2299 to_l1 = getattr(t, l1)(*p1) 2300 to_l2 = getattr(to_l1, l2)(*p2) 2301 except RuntimeError: 2302 return False 2303 2304 return True 2305 2306 self_strided = torch.rand(4, 4, dtype=torch.double) + 1 2307 grad_strided = torch.rand(4, 4, dtype=torch.double) + 1 2308 2309 for from_to_attr in to_attr_names: 2310 from_params = to_attr_names_params[from_to_attr] 2311 self_from = getattr(self_strided, from_to_attr)( 2312 *from_params 2313 ).requires_grad_(True) 2314 2315 for to_to_attr in to_attr_names[1:]: 2316 to_params = to_attr_names_params[to_to_attr] 2317 2318 if check_inversion_possible( 2319 self_strided, from_to_attr, from_params, to_to_attr, to_params 2320 ): 2321 self_to = getattr(self_from, to_to_attr)(*to_params) 2322 grad_to = getattr(grad_strided, to_to_attr)(*to_params) 2323 2324 # No gradcheck support for BSR/BSC, so the grads are 
checked explicitly 2325 grad_res = torch.autograd.grad(self_to, self_from, grad_to)[0] 2326 2327 self.assertEqual(grad_res.layout, self_from.layout) 2328 self.assertEqual(grad_res.to_dense(), grad_strided) 2329 2330 def test_sparse_mm_backward(self): 2331 size = (3, 3) 2332 2333 mm_test_cases = product(*(([False, True],) * 4)) 2334 2335 for a_req_grad, a_is_sparse, b_req_grad, b_is_sparse in mm_test_cases: 2336 # We should only be testing cases with sparse inputs, and at least one 2337 # input needs to require grad so we can call a backward pass 2338 if not ((a_is_sparse or b_is_sparse) and (a_req_grad or b_req_grad)): 2339 continue 2340 a = torch.randn(size) 2341 if a_is_sparse: 2342 # detaching as `a` needs to be a leaf 2343 a = a.to_sparse().detach() 2344 b = torch.randn(size) 2345 if b_is_sparse: 2346 # detaching as `b` needs to be a leaf 2347 b = b.to_sparse().detach() 2348 2349 a = a.requires_grad_(a_req_grad) 2350 b = b.requires_grad_(b_req_grad) 2351 2352 r = a.mm(b) 2353 s = r.sum().backward() 2354 a_grad = None if a.grad is None else a.grad.clone().detach() 2355 b_grad = None if b.grad is None else b.grad.clone().detach() 2356 2357 # Redo with only dense tensors 2358 a = ( 2359 (a.to_dense() if a.is_sparse else a) 2360 .clone() 2361 .detach() 2362 .requires_grad_(a_req_grad) 2363 ) 2364 b = ( 2365 (b.to_dense() if b.is_sparse else b) 2366 .clone() 2367 .detach() 2368 .requires_grad_(b_req_grad) 2369 ) 2370 2371 r = a.mm(b) 2372 r.sum().backward() 2373 2374 self.assertEqual(a_grad, a.grad) 2375 self.assertEqual(b_grad, b.grad) 2376 2377 def test_multi_backward(self): 2378 x = torch.randn(5, 5, requires_grad=True) 2379 y = torch.randn(5, 5, requires_grad=True) 2380 2381 q = torch.randn(5, 5, requires_grad=True) 2382 2383 a = torch.randn(5, 5, requires_grad=True) 2384 b = torch.randn(5, 5, requires_grad=True) 2385 2386 q2 = q * 2 2387 z = x + y + q2 2388 c = a * b + q2 2389 grad_z = torch.randn(5, 5) 2390 grad_c = torch.randn(5, 5) 2391 torch.autograd.backward([z, c], [grad_z, grad_c]) 2392 2393 self.assertEqual(x.grad, grad_z) 2394 self.assertEqual(y.grad, grad_z) 2395 self.assertEqual(a.grad, grad_c * b) 2396 self.assertEqual(b.grad, grad_c * a) 2397 self.assertEqual(q.grad, (grad_c + grad_z) * 2) 2398 2399 def test_multi_backward_no_grad(self): 2400 x = torch.randn(5, 5, requires_grad=True) 2401 y = torch.randn(5, 5, requires_grad=False) 2402 2403 z = x + y 2404 q = y * 2 2405 2406 # NB: we currently raise an exception if any arguments to backwards 2407 # have requires_grad=False and don't have a grad_fn. We may want to 2408 # relax that check to a warning. 
        def call_backwards():
            torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])

        self.assertRaises(RuntimeError, call_backwards)

    def test_backward_with_inputs(self):
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

        def fn():
            return x**2 + y * x + y**2

        gradient = torch.ones(2, 2)
        x_grad_expected = 2 * x + y
        y_grad_expected = x + 2 * y

        @torch.no_grad()
        def reset_grad():
            x.grad.zero_()
            y.grad.zero_()

        torch.autograd.backward(fn(), gradient, inputs=[x, y])
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(y.grad, y_grad_expected)

        reset_grad()
        torch.autograd.backward(fn(), gradient, inputs=[x])
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(y.grad, torch.zeros(2, 2), exact_dtype=False)

        reset_grad()
        torch.autograd.backward(fn(), gradient, inputs=[y])
        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)

        reset_grad()
        torch.autograd.backward(fn(), gradient, inputs=y)
        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)

        reset_grad()
        self.assertRaisesRegex(
            RuntimeError,
            "cannot be empty",
            lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
        )

    def test_backward_with_nonleaf_inputs(self):
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        x_nonleaf = x * 1
        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

        out = x_nonleaf**2 + y * x_nonleaf + y**2

        out.backward(
            torch.ones(2, 2, dtype=torch.double),
            create_graph=True,
            inputs=[x, y, x_nonleaf],
        )
        x_grad_expected = 2 * x + y
        y_grad_expected = x + 2 * y
        x_non_leaf_expected = 2 * x_nonleaf + y

        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)

        # backward doesn't have an allow_unused flag, so when a variable is not
        # part of the graph it behaves as if allow_unused were True: the grad
        # for that input will simply be None.
2480 out.backward( 2481 torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[z] 2482 ) 2483 self.assertIsNone(z.grad) 2484 2485 def test_dependent_backward(self): 2486 x = torch.randn(10, requires_grad=True) 2487 y = x**2 2488 z = y**3 2489 2490 go_y = torch.randn(10) 2491 go_z = torch.randn(10) 2492 torch.autograd.backward([y, z], [go_y, go_z]) 2493 2494 xd = x 2495 self.assertEqual(x.grad, 2 * xd * go_y + 6 * xd.pow(5) * go_z) 2496 2497 def test_save_output_nr(self): 2498 x = torch.randn(10, requires_grad=True) 2499 2500 class MultiOutputFn(Function): 2501 @staticmethod 2502 def forward(ctx, x): 2503 return x[:5], x[5:] 2504 2505 @staticmethod 2506 def backward(ctx, *grad): 2507 return torch.cat(grad) 2508 2509 a, b = MultiOutputFn.apply(x) 2510 self.assertEqual(b.output_nr, 1) 2511 2512 class TestFn(Function): 2513 @staticmethod 2514 def forward(ctx, b): 2515 ctx.save_for_backward(b) 2516 return b * 2 2517 2518 @staticmethod 2519 def backward(ctx, grad_b): 2520 (b,) = ctx.saved_tensors 2521 self.assertEqual(b.output_nr, 1) 2522 2523 TestFn.apply(b).sum().backward() 2524 2525 def test_first_grad_fn_access_in_no_grad_mode(self): 2526 a = torch.tensor([1 + 1j], requires_grad=True).clone() 2527 v = a.real 2528 a.add_(1) 2529 with torch.autograd.grad_mode.no_grad(): 2530 v.grad_fn 2531 2532 @skipIfTorchDynamo("too slow") 2533 def test_free_deep_graph(self): 2534 def scope(): 2535 depth = 150000 2536 x = torch.randn(1, requires_grad=True) 2537 y = x.clone() 2538 2539 # build a "chain" computation graph 2540 for _ in range(depth): 2541 y = y + y * 0.000001 2542 2543 # graph deletion occurs when the above locals go out of scope. 2544 # In this case `del y` will trigger it but it's easier to leave 2545 # it to Python to delete the locals. 2546 2547 # Should not stack overflow 2548 scope() 2549 2550 @skipIfTorchDynamo("too slow") 2551 def test_free_deep_graph_complicated(self): 2552 def scope(): 2553 depth = 100000 2554 randchoice = torch.randint(2, [depth, 2]) 2555 x = torch.randn(1, requires_grad=True) 2556 y = x.clone() 2557 2558 # Hold the two previous values 2559 prev_values = [None, None] 2560 2561 # Build a "chain with skip connections" graph 2562 for _ in range(depth): 2563 prev_tensors = [ 2564 tensor for tensor in prev_values[:-1] if tensor is not None 2565 ] 2566 prev_values.append(y) 2567 prev_values.pop(0) 2568 2569 # Definitely pick one tensor to add 2570 y += y * 0.000001 2571 2572 # Possibly add other tensors 2573 nprev = len(prev_tensors) 2574 if nprev == 2: 2575 y += randchoice[depth].mul(torch.cat(prev_tensors)).sum() 2576 2577 # graph deletion occurs when the above locals go out of scope. 2578 2579 # Should not stack overflow 2580 scope() 2581 2582 @skipIfTorchDynamo("too slow") 2583 def test_free_deep_graph_pyfunction(self): 2584 class MyOp(Function): 2585 @staticmethod 2586 def forward(ctx, tensor1, tensor2): 2587 return tensor1 + tensor2 2588 2589 @staticmethod 2590 def backward(ctx, grad_output): 2591 return grad_output, grad_output 2592 2593 def scope(): 2594 depth = 150000 2595 x = torch.randn(1, requires_grad=True) 2596 y = x.clone() 2597 2598 # build deeply nested computation graph 2599 for _ in range(depth): 2600 y = MyOp.apply(y, y) 2601 2602 # graph deletion occurs when the above locals go out of scope. 
2603 2604 # Should not stack overflow 2605 scope() 2606 2607 def test_no_unnecessary_save(self): 2608 # If we kept x in the derivative Function of x * 2 we would 2609 # get an error in the backward that would complain that we've 2610 # modified x, which was needed for gradient computation. 2611 # Since we should elide unnecessary saves, this test should pass. 2612 mu = torch.ones(1, requires_grad=True) 2613 x = torch.empty(1) 2614 loss = 0 2615 for i in range(3): 2616 x.detach_() 2617 x.copy_(mu + i) 2618 ft = torch.tensor([float(i)]) 2619 multiplied = x * ft 2620 s = multiplied.sum() 2621 loss += s 2622 loss.backward() 2623 2624 def test_no_grad(self): 2625 x = torch.ones(5, 5, requires_grad=True) 2626 y = torch.ones(5, 5) * 4 2627 with torch.no_grad(): 2628 w = x + y 2629 2630 def adder(x, y): 2631 return x + y 2632 2633 adders = [torch.no_grad()(adder), torch.no_grad(adder)] 2634 2635 for adder in adders: 2636 z = adder(x, y) 2637 2638 self.assertFalse(w.requires_grad) 2639 self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5))) 2640 self.assertIsNone(w.grad_fn) 2641 self.assertFalse(z.requires_grad) 2642 self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5))) 2643 self.assertIsNone(z.grad_fn) 2644 2645 # test nested decorator and with-statement on no_grad 2646 with torch.no_grad(): 2647 self.assertFalse(torch.is_grad_enabled()) 2648 w = adder(x, y) 2649 self.assertFalse(torch.is_grad_enabled()) 2650 2651 def test_enable_grad_decorator_no_paren(self): 2652 x = torch.ones(1, requires_grad=True) 2653 2654 @torch.enable_grad 2655 def doubler(x): 2656 return x * 2 2657 2658 with torch.no_grad(): 2659 z = doubler(x) 2660 self.assertTrue(z.requires_grad) 2661 2662 def test_set_grad_generator_functions(self): 2663 @torch.no_grad() 2664 def gen_no_grad(): 2665 for i in range(10): 2666 self.assertEqual(torch.is_grad_enabled(), False) 2667 yield i 2668 2669 with torch.enable_grad(): 2670 for _ in gen_no_grad(): 2671 self.assertEqual(torch.is_grad_enabled(), True) 2672 2673 @torch.enable_grad() 2674 def gen_enable_grad(): 2675 for i in range(10): 2676 self.assertEqual(torch.is_grad_enabled(), True) 2677 yield i 2678 2679 with torch.no_grad(): 2680 for _ in gen_enable_grad(): 2681 self.assertEqual(torch.is_grad_enabled(), False) 2682 2683 def test_set_grad_generator_functions_recursive(self): 2684 # enable_grad_decorator_recursive and no_grad_decorator_recursive call each other 2685 # recursively, to ensure that the decorators preserve the caller's setting 2686 @torch.enable_grad() 2687 def enable_grad_decorator_recursive(depth): 2688 self.assertTrue(torch.is_grad_enabled()) 2689 if depth > 0: 2690 no_grad_decorator_recursive(depth - 1) 2691 self.assertTrue(torch.is_grad_enabled()) 2692 2693 @torch.no_grad() 2694 def no_grad_decorator_recursive(depth): 2695 self.assertFalse(torch.is_grad_enabled()) 2696 if depth > 0: 2697 enable_grad_decorator_recursive(depth - 1) 2698 self.assertFalse(torch.is_grad_enabled()) 2699 2700 # enable_grad_context_manager_recursive and no_grad_context_manager_recursive call 2701 # each other recursively, to ensure that the decorators preserve the caller's setting 2702 def enable_grad_context_manager_recursive(depth): 2703 with torch.enable_grad(): 2704 self.assertTrue(torch.is_grad_enabled()) 2705 if depth > 0: 2706 no_grad_context_manager_recursive(depth - 1) 2707 self.assertTrue(torch.is_grad_enabled()) 2708 2709 def no_grad_context_manager_recursive(depth): 2710 with torch.no_grad(): 2711 self.assertFalse(torch.is_grad_enabled()) 2712 if 
depth > 0: 2713 enable_grad_context_manager_recursive(depth - 1) 2714 self.assertFalse(torch.is_grad_enabled()) 2715 2716 with torch.enable_grad(): 2717 self.assertTrue(torch.is_grad_enabled()) 2718 enable_grad_decorator_recursive(10) 2719 self.assertTrue(torch.is_grad_enabled()) 2720 enable_grad_context_manager_recursive(10) 2721 self.assertTrue(torch.is_grad_enabled()) 2722 2723 with torch.no_grad(): 2724 self.assertFalse(torch.is_grad_enabled()) 2725 enable_grad_decorator_recursive(10) 2726 self.assertFalse(torch.is_grad_enabled()) 2727 enable_grad_context_manager_recursive(10) 2728 self.assertFalse(torch.is_grad_enabled()) 2729 2730 def test_set_grad_coroutines(self): 2731 @torch.no_grad() 2732 def coro_no_grad(n=10): 2733 self.assertFalse(torch.is_grad_enabled()) 2734 for i in range(n): 2735 self.assertFalse(torch.is_grad_enabled()) 2736 r = yield i 2737 self.assertFalse(torch.is_grad_enabled()) 2738 self.assertEqual(i, r) 2739 self.assertFalse(torch.is_grad_enabled()) 2740 2741 @torch.enable_grad() 2742 def coro_enable_grad(n=10): 2743 self.assertTrue(torch.is_grad_enabled()) 2744 for i in range(n): 2745 self.assertTrue(torch.is_grad_enabled()) 2746 r = yield i 2747 self.assertTrue(torch.is_grad_enabled()) 2748 self.assertEqual(i, r) 2749 self.assertTrue(torch.is_grad_enabled()) 2750 2751 with torch.enable_grad(): 2752 self.assertTrue(torch.is_grad_enabled()) 2753 coro, r = coro_no_grad(), None 2754 try: 2755 while True: 2756 self.assertTrue(torch.is_grad_enabled()) 2757 r = coro.send(r) 2758 self.assertTrue(torch.is_grad_enabled()) 2759 2760 except StopIteration: 2761 pass 2762 2763 with torch.no_grad(): 2764 self.assertFalse(torch.is_grad_enabled()) 2765 coro, r = coro_enable_grad(), None 2766 try: 2767 while True: 2768 self.assertFalse(torch.is_grad_enabled()) 2769 r = coro.send(r) 2770 self.assertFalse(torch.is_grad_enabled()) 2771 2772 except StopIteration: 2773 pass 2774 2775 def test_set_grad_coroutines_benign_exceptions(self): 2776 class RecoverableException(Exception): 2777 pass 2778 2779 @torch.no_grad() 2780 def coro_no_grad(n=10): 2781 has_raised = False 2782 for i in range(n): 2783 try: 2784 self.assertFalse(torch.is_grad_enabled()) 2785 yield (-i if has_raised else i) 2786 2787 except RecoverableException: 2788 self.assertFalse(torch.is_grad_enabled()) 2789 has_raised = True 2790 2791 @torch.enable_grad() 2792 def coro_enable_grad(n=10): 2793 has_raised = False 2794 for i in range(n): 2795 try: 2796 self.assertTrue(torch.is_grad_enabled()) 2797 yield (-i if has_raised else i) 2798 2799 except RecoverableException: 2800 self.assertTrue(torch.is_grad_enabled()) 2801 has_raised = True 2802 2803 with torch.enable_grad(): 2804 coro = coro_no_grad() 2805 assert 0 == next(coro) 2806 try: 2807 while True: 2808 r = coro.throw(RecoverableException) 2809 self.assertLess(r, 0) 2810 2811 except StopIteration: 2812 pass 2813 2814 with torch.no_grad(): 2815 coro = coro_enable_grad() 2816 assert 0 == next(coro) 2817 try: 2818 while True: 2819 r = coro.throw(RecoverableException) 2820 self.assertLess(r, 0) 2821 2822 except StopIteration: 2823 pass 2824 2825 def test_set_grad_coroutines_critical_exceptions(self): 2826 class UnrecoverableException(Exception): 2827 pass 2828 2829 class SecondaryException(Exception): 2830 pass 2831 2832 @torch.no_grad() 2833 def coro_no_grad(n=10): 2834 has_raised = False 2835 for i in range(n): 2836 try: 2837 self.assertFalse(torch.is_grad_enabled()) 2838 yield (-i if has_raised else i) 2839 2840 except UnrecoverableException: 2841 
self.assertFalse(torch.is_grad_enabled()) 2842 raise SecondaryException from None 2843 2844 @torch.enable_grad() 2845 def coro_enable_grad(n=10): 2846 has_raised = False 2847 for i in range(n): 2848 try: 2849 self.assertTrue(torch.is_grad_enabled()) 2850 yield (-i if has_raised else i) 2851 2852 except UnrecoverableException: 2853 self.assertTrue(torch.is_grad_enabled()) 2854 raise SecondaryException from None 2855 2856 with torch.enable_grad(): 2857 coro = coro_no_grad() 2858 assert 0 == next(coro) 2859 with self.assertRaises(SecondaryException): 2860 coro.throw(UnrecoverableException) 2861 2862 with torch.no_grad(): 2863 coro = coro_enable_grad() 2864 assert 0 == next(coro) 2865 with self.assertRaises(SecondaryException): 2866 coro.throw(UnrecoverableException) 2867 2868 def test_set_grad_coroutines_exit(self): 2869 @torch.no_grad() 2870 def coro_no_grad(state): 2871 for i in range(10): 2872 try: 2873 self.assertFalse(torch.is_grad_enabled()) 2874 yield i 2875 2876 except GeneratorExit: 2877 self.assertFalse(torch.is_grad_enabled()) 2878 state.add("GeneratorExit") 2879 raise 2880 2881 @torch.enable_grad() 2882 def coro_enable_grad(state): 2883 for i in range(10): 2884 try: 2885 self.assertTrue(torch.is_grad_enabled()) 2886 yield i 2887 2888 except GeneratorExit: 2889 self.assertTrue(torch.is_grad_enabled()) 2890 state.add("GeneratorExit") 2891 raise 2892 2893 state = set() 2894 with torch.enable_grad(): 2895 coro = coro_no_grad(state) 2896 for i in range(5): 2897 next(coro) 2898 2899 coro.close() 2900 self.assertTrue("GeneratorExit" in state) 2901 2902 state = set() 2903 with torch.no_grad(): 2904 coro = coro_enable_grad(state) 2905 for i in range(5): 2906 next(coro) 2907 2908 coro.close() 2909 self.assertTrue("GeneratorExit" in state) 2910 2911 def test_no_grad_python_function(self): 2912 """Python Functions should respect grad mode.""" 2913 x = torch.ones(5, 5, requires_grad=True) 2914 2915 class MyOp(Function): 2916 @staticmethod 2917 def forward(self, x): 2918 return x + 1 2919 2920 @staticmethod 2921 def backward(self, dy): 2922 return dy 2923 2924 with torch.no_grad(): 2925 y = MyOp.apply(x) 2926 self.assertFalse(y.requires_grad) 2927 2928 def test_indexing(self): 2929 x = torch.arange(1.0, 17).view(4, 4) 2930 y = Variable(x, requires_grad=True) 2931 2932 def compare(x, y, idx, indexed_tensor, indexed_var): 2933 indexed_var_t = indexed_var.data 2934 if not isinstance(indexed_tensor, torch.Tensor): 2935 indexed_var_t = indexed_var_t[0] 2936 self.assertEqual(indexed_tensor, indexed_var_t) 2937 2938 indexed_var.sum().backward() 2939 expected_grad = torch.empty(x.size()).fill_(0) 2940 expected_grad[idx] = 1 2941 self.assertEqual(y.grad, expected_grad) 2942 2943 def check_index(x, y, idx): 2944 if y.grad is not None: 2945 with torch.no_grad(): 2946 y.grad.zero_() 2947 indexed_tensor = x[idx] 2948 indexed_var = y[idx] 2949 compare(x, y, idx, indexed_tensor, indexed_var) 2950 2951 check_index(x, y, 1) 2952 check_index(x, y, (1, 1)) 2953 check_index(x, y, slice(1, None)) 2954 check_index(x, y, slice(None, 2)) 2955 check_index(x, y, (slice(None, 2), 2)) 2956 check_index(x, y, (slice(1, 2), 2)) 2957 check_index(x, y, (1, slice(2, None))) 2958 check_index(x, y, (slice(None, None), slice(2, None))) 2959 check_index(x, y, torch.LongTensor([0, 2])) 2960 check_index(x, y, torch.rand(4, 4).bernoulli().bool()) 2961 check_index(x, y, (Ellipsis, slice(2, None))) 2962 check_index(x, y, ([0], [0])) 2963 check_index(x, y, ([1, 2, 3], [0])) 2964 check_index(x, y, ([1, 2], [2, 1])) 2965 check_index(x, y, 
([[1, 2], [3, 0]], [[0, 1], [2, 3]])) 2966 check_index(x, y, ([slice(None), [2, 3]])) 2967 check_index(x, y, ([[2, 3], slice(None)])) 2968 2969 # advanced indexing, with less dim, or ellipsis 2970 check_index(x, y, ([0])) 2971 check_index(x, y, ([0],)) 2972 2973 x = torch.arange(1.0, 49).view(4, 3, 4) 2974 y = Variable(x, requires_grad=True) 2975 2976 check_index(x, y, (slice(None), [0], [0])) 2977 check_index(x, y, ([0], [0], slice(None))) 2978 check_index(x, y, (slice(None), [0, 1, 2], [0])) 2979 check_index(x, y, ([0, 1, 2], [0], slice(None))) 2980 check_index(x, y, (slice(None), [1, 2], [2, 1])) 2981 check_index(x, y, ([1, 2], [2, 1], slice(None))) 2982 check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]])) 2983 check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None))) 2984 check_index(x, y, (slice(None), slice(None), [2, 1])) 2985 check_index(x, y, (slice(None), [2, 1], slice(None))) 2986 check_index(x, y, ([2, 1], slice(None), slice(None))) 2987 2988 # advanced indexing, with less dim, or ellipsis 2989 check_index(x, y, ([0],)) 2990 check_index(x, y, ([0], slice(None))) 2991 check_index(x, y, ([0], Ellipsis)) 2992 check_index(x, y, ([1, 2], [0, 1])) 2993 check_index(x, y, ([1, 2], [0, 1], Ellipsis)) 2994 check_index(x, y, (Ellipsis, [1, 2], [0, 1])) 2995 2996 # advanced indexing, with a tensor wrapped in a variable 2997 z = torch.LongTensor([0, 1]) 2998 zv = Variable(z, requires_grad=False) 2999 seq = [z, Ellipsis] 3000 seqv = [zv, Ellipsis] 3001 3002 if y.grad is not None: 3003 with torch.no_grad(): 3004 y.grad.zero_() 3005 indexed_tensor = x[seq] 3006 indexed_var = y[seqv] 3007 compare(x, y, seq, indexed_tensor, indexed_var) 3008 3009 def test_indexing_duplicates(self): 3010 x = torch.arange(1.0, 17).view(4, 4) 3011 y = Variable(x, requires_grad=True) 3012 3013 idx = torch.LongTensor([1, 1, 3, 2, 1, 2]) 3014 y[idx].sum().backward() 3015 expected_grad = torch.zeros(4, 4) 3016 for i in idx: 3017 expected_grad[i] += 1 3018 self.assertEqual(y.grad, expected_grad) 3019 3020 # with advanced indexing 3021 x = torch.arange(1.0, 17).view(4, 4) 3022 y = Variable(x, requires_grad=True) 3023 3024 idx = [[1, 1, 3, 2, 1, 2], [0]] 3025 y[idx].sum().backward() 3026 expected_grad = torch.zeros(4, 4) 3027 for i in idx[0]: 3028 for j in idx[1]: 3029 expected_grad[i][j] += 1 3030 3031 self.assertEqual(y.grad, expected_grad) 3032 3033 x = torch.arange(1.0, 17).view(4, 4) 3034 y = Variable(x, requires_grad=True) 3035 idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]] 3036 y[idx].sum().backward() 3037 expected_grad = torch.tensor( 3038 [ 3039 [0.0, 2.0, 0.0, 0.0], 3040 [1.0, 0.0, 0.0, 0.0], 3041 [0.0, 1.0, 0.0, 0.0], 3042 [0.0, 0.0, 0.0, 0.0], 3043 ] 3044 ) 3045 self.assertEqual(y.grad, expected_grad) 3046 3047 x = torch.arange(1.0, 65).view(4, 4, 4) 3048 y = Variable(x, requires_grad=True) 3049 3050 idx = [[1, 1, 1], slice(None), slice(None)] 3051 y[idx].sum().backward() 3052 expected_grad = torch.empty(4, 4, 4).zero_() 3053 expected_grad[1].fill_(3) 3054 self.assertEqual(y.grad, expected_grad) 3055 3056 def test_index_backward_does_not_save_tensor(self): 3057 # Example from https://github.com/pytorch/pytorch/issues/24853. 3058 # if `index(tensor, indices)` saves `tensor` for backwards, then it will 3059 # trigger a version check on `tensor` during the backward pass, which 3060 # will cause the following code to error because `tensor` gets modified 3061 # by the indexing line. 
3062 a = torch.tensor([1.0, 0, 0]) 3063 b = torch.zeros(3, requires_grad=True) 3064 tensor = b + 0 3065 tensor[a != 0] = tensor[a != 0] 3066 tensor.backward(torch.zeros_like(tensor)) 3067 3068 def test_volatile_deprecated(self): 3069 v = torch.autograd.torch.randn(3, 3) 3070 with warnings.catch_warnings(record=True) as w: 3071 self.assertFalse(v.volatile) 3072 self.assertIn("volatile", str(w[0].message)) 3073 3074 def test_saved_variables_deprecated(self): 3075 class MyFunction(Function): 3076 @staticmethod 3077 def forward(ctx, tensor1, tensor2): 3078 ctx.save_for_backward(tensor1, tensor2) 3079 return tensor1 + tensor2 3080 3081 @staticmethod 3082 def backward(ctx, grad_output): 3083 var1, var2 = ctx.saved_variables 3084 return (grad_output, grad_output) 3085 3086 with warnings.catch_warnings(record=True) as warns: 3087 warnings.simplefilter("always") 3088 x = torch.randn((3, 3), requires_grad=True) 3089 y = torch.randn((3, 3), requires_grad=True) 3090 MyFunction.apply(x, y).sum().backward() 3091 3092 has_deprecated = ( 3093 "deprecated" in str(warn) and "saved_variables" in str(warn) 3094 for warn in warns 3095 ) 3096 has_deprecated = reduce(lambda x, y: x or y, has_deprecated) 3097 self.assertTrue(has_deprecated) 3098 3099 def test_requires_grad(self): 3100 x = torch.randn(5, 5) 3101 y = torch.randn(5, 5) 3102 z = torch.randn(5, 5, requires_grad=True) 3103 a = x + y 3104 self.assertFalse(a.requires_grad) 3105 b = a + z 3106 self.assertTrue(b.requires_grad) 3107 3108 def error(): 3109 raise RuntimeError 3110 3111 # Make sure backward isn't called on these 3112 a._backward_hooks = OrderedDict() 3113 x._backward_hooks = OrderedDict() 3114 y._backward_hooks = OrderedDict() 3115 a._backward_hooks["test"] = error 3116 x._backward_hooks["test"] = error 3117 y._backward_hooks["test"] = error 3118 b.backward(torch.ones(5, 5)) 3119 3120 def test_requires_grad_(self): 3121 x = torch.randn(5, 5) 3122 y = torch.randn(5, 5, requires_grad=True) 3123 self.assertIs(x, x.requires_grad_()) 3124 self.assertTrue(x.requires_grad) 3125 self.assertIs(y, y.requires_grad_()) 3126 self.assertTrue(y.requires_grad) 3127 self.assertIs(x, x.requires_grad_(True)) 3128 self.assertTrue(x.requires_grad) 3129 self.assertIs(y, y.requires_grad_(True)) 3130 self.assertTrue(y.requires_grad) 3131 z = x * y 3132 self.assertRaises(RuntimeError, lambda: z.requires_grad_(False)) 3133 self.assertIs(z, z.requires_grad_()) 3134 self.assertTrue(z.requires_grad) 3135 self.assertIs(z, z.requires_grad_(True)) 3136 self.assertTrue(z.requires_grad) 3137 3138 self.assertIs(x, x.requires_grad_(False)) 3139 self.assertFalse(x.requires_grad) 3140 self.assertIs(y, y.requires_grad_(False)) 3141 self.assertFalse(y.requires_grad) 3142 3143 def test_requires_grad_inplace(self): 3144 a = torch.randn(5, 5) 3145 b = torch.randn(5, 5, requires_grad=True) 3146 a += b 3147 self.assertTrue(a.requires_grad) 3148 3149 # non-leaf 3150 a = torch.randn(5, 5) + 0 3151 b = torch.randn(5, 5, requires_grad=True) 3152 a += b 3153 self.assertTrue(a.requires_grad) 3154 3155 def test_no_requires_grad_inplace(self): 3156 # basic case, should be able to modify inplace while requires_grad is False 3157 a = torch.randn(2, 3) 3158 a.add_(5) 3159 a.requires_grad = True 3160 a.sum().backward() 3161 self.assertEqual(a.grad, torch.ones(2, 3)) 3162 3163 # same but with a view 3164 a = torch.randn(2, 3) 3165 b = a[:] 3166 b.add_(5) 3167 a.requires_grad = True 3168 a.sum().backward() 3169 self.assertEqual(a.grad, torch.ones(2, 3)) 3170 3171 # should fail if requires_grad = True 
when we modify inplace 3172 a = torch.randn(2, 3) 3173 b = a[:] 3174 a.requires_grad = True 3175 with self.assertRaises(RuntimeError): 3176 a.add_(5) 3177 with self.assertRaises(RuntimeError): 3178 b.add_(5) 3179 3180 def test_attribute_deletion(self): 3181 x = torch.randn((5, 5), requires_grad=True) 3182 del x.grad 3183 self.assertIsNone(x.grad) 3184 with self.assertRaises(RuntimeError): 3185 del x.data 3186 with self.assertRaises(TypeError): 3187 x.data = None 3188 with self.assertRaises(RuntimeError): 3189 del x.requires_grad 3190 with self.assertRaises(RuntimeError): 3191 del x._grad_fn 3192 with self.assertRaises(RuntimeError): 3193 del x._backward_hooks 3194 3195 def test_duplicate_backward_root(self): 3196 a = torch.randn(5, 5, requires_grad=True) 3197 b = torch.randn(5, 5, requires_grad=True) 3198 3199 x = a * b 3200 grad_output = torch.randn_like(x) 3201 torch.autograd.backward([x, x], [grad_output, grad_output]) 3202 3203 self.assertEqual(a.grad, b * grad_output * 2) 3204 self.assertEqual(b.grad, a * grad_output * 2) 3205 3206 def test_backward_no_grad(self): 3207 a = torch.randn(5, 5, requires_grad=True) 3208 b = a + 2 3209 with self.assertRaises(RuntimeError): 3210 torch.autograd.backward([b], [None]) 3211 3212 def test_backward_twice_with_saved_values(self): 3213 b = torch.randn(3, requires_grad=True, dtype=torch.double) 3214 c = torch.zeros(3, dtype=torch.double) 3215 c[[1, 2]] = b[[1, 1]] 3216 c.backward(torch.tensor([1, 1, 1], dtype=torch.double)) 3217 self.assertRaisesRegex( 3218 RuntimeError, 3219 "Specify retain_graph=True", 3220 lambda: c.backward(torch.tensor([1, 1, 1], dtype=torch.double)), 3221 ) 3222 3223 def test_backward_twice_retained_graph_with_saved_values(self): 3224 b = torch.randn(3, requires_grad=True, dtype=torch.double) 3225 c = torch.zeros(3, dtype=torch.double) 3226 c[[1, 2]] = b[[1, 1]] 3227 c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True) 3228 c.backward(torch.tensor([1, 1, 1], dtype=torch.double)) 3229 3230 def test_backward_twice_without_saved_values(self): 3231 b = torch.randn(3, requires_grad=True, dtype=torch.double) 3232 c = b + 1 3233 c.backward(torch.tensor([1, 1, 1], dtype=torch.double)) 3234 c.backward(torch.tensor([1, 1, 1], dtype=torch.double)) 3235 3236 def test_backward_twice_retained_graph_without_saved_values(self): 3237 b = torch.randn(3, requires_grad=True, dtype=torch.double) 3238 c = torch.zeros(3, dtype=torch.double) 3239 c[[1, 2]] = b[[1, 1]] 3240 c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True) 3241 c.backward(torch.tensor([1, 1, 1], dtype=torch.double)) 3242 3243 def test_backward_create_graph_warns(self): 3244 with set_warn_always_context(True): 3245 b = torch.randn(3, requires_grad=True, dtype=torch.double) 3246 c = b * b 3247 with warnings.catch_warnings(record=True) as ws: 3248 c.backward(torch.ones_like(c), create_graph=True) 3249 b.grad = None 3250 self.assertTrue( 3251 any( 3252 "Using backward() with create_graph=True" in str(w.message) 3253 for w in ws 3254 ) 3255 ) 3256 3257 # Should not warn for grad 3258 with warnings.catch_warnings(record=True) as ws: 3259 torch.autograd.grad(c, b, torch.ones_like(c), create_graph=True) 3260 self.assertFalse( 3261 any( 3262 "Using backward() with create_graph=True" in str(w.message) 3263 for w in ws 3264 ) 3265 ) 3266 3267 def test_next_functions(self): 3268 x = torch.randn(5, 5, requires_grad=True) 3269 y = torch.randn(5, 5, requires_grad=True) 3270 3271 a = x + y 3272 self.assertIsNotNone(a.grad_fn) 3273 next_functions = 
a.grad_fn.next_functions
        self.assertEqual(len(next_functions), 2)
        self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
        self.assertEqual(next_functions[0][1], 0)
        self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
        self.assertEqual(next_functions[1][1], 0)

        b = a + 5
        next_functions = b.grad_fn.next_functions
        self.assertEqual(len(next_functions), 2)
        self.assertIs(next_functions[0][0], a.grad_fn)
        self.assertIs(next_functions[1][0], None)

    def test_inplace(self):
        x = torch.ones(5, 5, requires_grad=True)
        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

        z = x * y
        q = z + y
        w = z * y
        z.add_(2)
        # Add doesn't need its inputs to do backward, so it shouldn't raise
        q.backward(torch.ones(5, 5), retain_graph=True)
        # Mul saves both inputs in forward, so it should raise
        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))

        z = x * y
        q = z * y
        r = z + y
        w = z.add_(y)
        # w is the last expression, so this should succeed
        w.backward(torch.ones(5, 5), retain_graph=True)
        # r doesn't use the modified value in backward, so it should succeed
        r.backward(torch.ones(5, 5), retain_graph=True)
        # q uses dirty z, so it should raise
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        with torch.no_grad():
            x.grad.zero_()
        m = x / 2
        z = m + y / 8
        q = z * y
        r = z + y
        prev_version = z._version
        w = z.exp_()
        self.assertNotEqual(z._version, prev_version)
        r.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(x.grad, torch.ones(5, 5) / 2)
        w.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(x.grad, torch.empty(5, 5).fill_((1 + math.e) / 2))
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        leaf = torch.ones(5, 5, requires_grad=True)
        x = leaf.clone()
        x.add_(10)
        self.assertEqual(x, torch.ones(5, 5) * 11)
        # x should still be usable
        y = x + 2
        y.backward(torch.ones(5, 5))
        self.assertEqual(leaf.grad, torch.ones(5, 5))
        z = x * y
        x.add_(2)
        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))

    def test_mark_non_differentiable(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                output = input > 0
                ctx.mark_non_differentiable(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                return (grad_output * 0).to(torch.double)

        x = torch.randn(5, 5, requires_grad=True)
        mask = MyFunction.apply(x)
        self.assertFalse(mask.requires_grad)
        y = x.masked_fill(mask, 0)
        y.sum().backward()

    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
    def test_mark_non_differentiable_mixed(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                a = input + 1
                b = input + 2
                ctx.mark_non_differentiable(a)
                return a, b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                self.assertTrue((grad_a == 0).all())
                self.assertTrue((grad_b == 1).all())
                return grad_b

        x = torch.randn(5, 5, requires_grad=True)
        a, b = MyFunction.apply(x)
        self.assertFalse(a.requires_grad)
        self.assertTrue(b.requires_grad)
        b.sum().backward()
        self.assertEqual(x.grad, torch.ones(5, 5))
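    # NOTE: The inplace tests above rely on the saved-tensor version counter:
    # every in-place op bumps `tensor._version`, and backward raises if a
    # tensor that a node saved in forward was mutated afterwards. The sketch
    # below is not part of the original suite: the method name is ours, it is
    # deliberately not named test_* so it is never collected, and it only
    # restates that mechanism in its smallest form using the public API.
    def _sketch_saved_tensor_version_counter(self):
        a = torch.ones(5, requires_grad=True)
        b = a.clone()  # non-leaf, so in-place ops on it are allowed
        c = a * b  # Mul saves both inputs for backward
        before = b._version
        b.add_(1)  # mutating a saved tensor bumps its version counter
        self.assertNotEqual(b._version, before)
        with self.assertRaisesRegex(RuntimeError, "modified by an inplace"):
            c.sum().backward()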
3377 3378 def test_mark_non_differentiable_none(self): 3379 # This used to segfault because MyFunction would send back null 3380 # gradients to MulBackward, which is implemented in C++. C++ 3381 # implemented functions expect incoming grad_outputs to be non-null. 3382 class MyFunction(Function): 3383 @staticmethod 3384 def forward(ctx, input): 3385 output = input.clone() 3386 ctx.mark_non_differentiable(output) 3387 return output 3388 3389 @staticmethod 3390 def backward(ctx, grad_output): 3391 return None 3392 3393 x = torch.randn(5, 5, requires_grad=True) 3394 r = MyFunction.apply(x * x) 3395 (r * x).sum().backward() 3396 3397 def test_return_duplicate(self): 3398 class DoubleDuplicate(Function): 3399 @staticmethod 3400 def forward(ctx, x): 3401 output = x * 2 3402 return output, output 3403 3404 @staticmethod 3405 def backward(ctx, grad1, grad2): 3406 return grad1 * 2 + grad2 * 2 3407 3408 def fn(x): 3409 a, b = DoubleDuplicate.apply(x) 3410 self.assertIs(a, b) 3411 return a + b 3412 3413 x = torch.randn(5, 5, dtype=torch.double, requires_grad=True) 3414 gradcheck(fn, [x]) 3415 gradgradcheck(fn, [x]) 3416 3417 def test_return_duplicate_inplace(self): 3418 class DoubleInplace(Function): 3419 @staticmethod 3420 def forward(ctx, x): 3421 x.mul_(2) 3422 ctx.mark_dirty(x) 3423 return x, x 3424 3425 @staticmethod 3426 def backward(ctx, grad1, grad2): 3427 return grad1 * 2 + grad2 * 2 3428 3429 def inplace_fn(x): 3430 a, b = DoubleInplace.apply(x.clone()) 3431 self.assertIs(a, b) 3432 return a + b 3433 3434 x = torch.randn(5, 5, dtype=torch.double, requires_grad=True) 3435 gradcheck(inplace_fn, [x]) 3436 gradgradcheck(inplace_fn, [x]) 3437 3438 # Can't modify leaf variables in-place 3439 self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x)) 3440 # Functions which modify views in-place must return only one output 3441 self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x.clone()[0])) 3442 3443 def _test_setitem(self, size, index): 3444 x = torch.ones(*size, requires_grad=True) 3445 y = x + 2 3446 y_version = y._version 3447 y[index] = 2 3448 self.assertNotEqual(y._version, y_version) 3449 y.backward(torch.ones(*size)) 3450 expected_grad = torch.ones(*size) 3451 expected_grad[index] = 0 3452 self.assertEqual(x.grad, expected_grad) 3453 3454 def _test_setitem_tensor(self, size, index): 3455 x = torch.ones(*size, requires_grad=True) 3456 y = x + 2 3457 y_version = y._version 3458 value = x.new(x[index].size()).fill_(7) 3459 value.requires_grad = True 3460 y[index] = value 3461 self.assertNotEqual(y._version, y_version) 3462 y.backward(torch.ones(*size)) 3463 expected_grad_input = torch.ones(*size) 3464 expected_grad_input[index] = 0 3465 self.assertEqual(x.grad, expected_grad_input) 3466 self.assertEqual(value.grad, torch.ones_like(value)) 3467 3468 # case when x broadcasts to as y[1] 3469 x = torch.randn(4, requires_grad=True) 3470 y = torch.zeros(2, 3, 4) 3471 y[1] = x 3472 y.backward(torch.randn(2, 3, 4)) 3473 self.assertEqual(x.size(), x.grad.size()) 3474 3475 def test_setitem(self): 3476 self._test_setitem((5, 5), 1) 3477 self._test_setitem((5,), 1) 3478 self._test_setitem((1,), 0) 3479 self._test_setitem((10,), [[0, 4, 2]]) 3480 self._test_setitem((5, 5), [[0, 4], [2, 2]]) 3481 self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]]) 3482 self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)]) 3483 self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)]) 3484 self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]]) 3485 
self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)]) 3486 self._test_setitem_tensor((5, 5), 3) 3487 self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]]) 3488 self._test_setitem_tensor((5,), 3) 3489 self._test_setitem_tensor( 3490 (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum() 3491 ) 3492 self._test_setitem_tensor((5,), [[0, 1, 2, 3]]) 3493 self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]]) 3494 self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)]) 3495 self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)]) 3496 self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]]) 3497 self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)]) 3498 self._test_setitem_tensor( 3499 (5, 5, 5), 3500 [ 3501 Variable(torch.LongTensor([1, 3]), requires_grad=False), 3502 [2, 4], 3503 slice(None), 3504 ], 3505 ) 3506 3507 def test_setitem_mask(self): 3508 mask = torch.BoolTensor(5, 5).bernoulli_() 3509 self._test_setitem((5, 5), Variable(mask)) 3510 self._test_setitem((5,), Variable(mask[0])) 3511 self._test_setitem((1,), Variable(mask[0, 0:1])) 3512 self._test_setitem_tensor((5, 5), Variable(mask)) 3513 self._test_setitem_tensor((5,), Variable(mask[0])) 3514 3515 def test_select_sum(self): 3516 # both select and sum return Scalars in ATen; ensure they work together. 3517 x = torch.randn(10, dtype=torch.double, requires_grad=True) 3518 3519 def func(x): 3520 return x.select(0, 1).sum() 3521 3522 gradcheck(func, [x]) 3523 gradgradcheck(func, [x]) 3524 3525 def test_diagonal_expanded_v(self): 3526 value = torch.rand([]) 3527 v_expanded = torch.tensor(value).expand(10) 3528 a = torch.rand(10, 10, dtype=torch.double, requires_grad=True) 3529 (result,) = torch.autograd.grad(a.diagonal(), a, v_expanded) 3530 self.assertEqual(result, torch.eye(10, dtype=torch.double) * value) 3531 3532 def test_select_expanded_v(self): 3533 v_expanded = torch.rand(10).expand(10, 10) 3534 a = torch.rand(10, 10, 10, requires_grad=True) 3535 (result,) = torch.autograd.grad(a[0], a, v_expanded) 3536 expected = torch.zeros(10, 10, 10) 3537 expected[0] = v_expanded 3538 self.assertEqual(result, expected) 3539 3540 def test_slice_expanded_v(self): 3541 v_expanded = torch.rand(10, 1).expand(2, 10, 10) 3542 a = torch.rand(10, 10, 10, requires_grad=True) 3543 (result,) = torch.autograd.grad(a[3:5], a, v_expanded) 3544 expected = torch.zeros(10, 10, 10) 3545 expected[3:5] = v_expanded 3546 self.assertEqual(result, expected) 3547 3548 def test_unused_output(self): 3549 x = torch.randn(10, 10, requires_grad=True) 3550 outputs = x.chunk(5) 3551 o = outputs[2] 3552 o = o * 4 + 2 3553 o.sum().backward() 3554 expected_grad = torch.zeros(10, 10) 3555 expected_grad[4:6] = 4 3556 self.assertEqual(x.grad, expected_grad) 3557 3558 with torch.no_grad(): 3559 x.grad.zero_() 3560 grad_output = torch.randn(2, 10) 3561 outputs = x.chunk(5) 3562 outputs[0].backward(grad_output) 3563 expected_grad = torch.zeros(10, 10) 3564 expected_grad[:2] = grad_output 3565 self.assertEqual(x.grad, expected_grad) 3566 3567 # TODO: opinfo this or move to the sparse test suite 3568 def _test_sparse_gather(self, size_x, size_ind, dim): 3569 x = torch.randn(size_x, requires_grad=True) 3570 if len(size_ind) > 0 and len(size_x) > 0: 3571 ind = torch.randint(x.size(dim), size_ind) 3572 else: 3573 ind = torch.zeros(size_ind, dtype=torch.int64) 3574 out = torch.gather(x, dim, ind, sparse_grad=False) 3575 grad = torch.rand_like(out) 3576 out.backward(grad) 3577 grad_dense = 
x.grad.clone() 3578 x.grad = None 3579 out = torch.gather(x, dim, ind, sparse_grad=True) 3580 out.backward(grad) 3581 self.assertEqual(grad_dense, x.grad.to_dense()) 3582 3583 def test_sparse_gather_dim0(self): 3584 self._test_sparse_gather((10, 10), (5, 10), 0) 3585 3586 def test_sparse_gather_dim1(self): 3587 self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1) 3588 3589 def test_sparse_gather_dim_neg(self): 3590 self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1) 3591 3592 def test_sparse_gather_ind_scalar(self): 3593 self._test_sparse_gather((10,), (), 0) 3594 3595 def test_sparse_gather_x_scalar(self): 3596 self._test_sparse_gather((), (2,), 0) 3597 3598 def test_sparse_gather_both_scalar(self): 3599 self._test_sparse_gather((), (), 0) 3600 3601 def test_gc_in_destructor(self): 3602 """ 3603 Previously, if a Function destructor triggered a garbage collection, 3604 the Variable's tp_dealloc handler would get called twice leading to a 3605 segfault. 3606 """ 3607 3608 class CollectOnDelete(Function): 3609 def forward(self, x): 3610 return x 3611 3612 def backward(self, grad_output): 3613 return grad_output 3614 3615 def __del__(self): 3616 gc.collect() 3617 3618 for _ in range(10): 3619 CollectOnDelete().forward(torch.randn(1, requires_grad=True)).backward() 3620 3621 def test_naughty_autograd_function_attribute_access(self): 3622 class Id(Function): 3623 @staticmethod 3624 def forward(ctx, x): 3625 return x 3626 3627 @staticmethod 3628 def backward(ctx, grad_x): 3629 return grad_x 3630 3631 with self.assertWarnsRegex(DeprecationWarning, "should not be instantiated"): 3632 f = Id() 3633 3634 # After raising warning, should still return an instance 3635 self.assertIsInstance(f, Id) 3636 x = torch.zeros(1, requires_grad=True) 3637 with self.assertRaisesRegex( 3638 RuntimeError, "non-static forward method is deprecated" 3639 ): 3640 f(x) 3641 t = Id.apply(x) 3642 self.assertEqual(t.grad_fn.name(), "IdBackward") 3643 3644 # THPFunction is the base class of both grad_fn and autograd functions, 3645 # which means that a lot of accessors on them may segfault. Test that we 3646 # properly error in this case. 
3647 t = torch.ones(1, requires_grad=True) 3648 t._backward_hooks = {} 3649 with self.assertRaisesRegex( 3650 RuntimeError, "Attribute '_register_hook_dict' is invalid" 3651 ): 3652 f._register_hook_dict(t) 3653 with self.assertRaisesRegex( 3654 RuntimeError, "Attribute 'register_hook' is invalid" 3655 ): 3656 f.register_hook(lambda x, y: None) 3657 with self.assertRaisesRegex( 3658 RuntimeError, "Attribute 'next_functions' is invalid" 3659 ): 3660 f.next_functions 3661 with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"): 3662 f.name() 3663 with self.assertRaisesRegex( 3664 RuntimeError, "underlying PyNode has already been deallocated" 3665 ): 3666 f.metadata 3667 3668 @unittest.expectedFailure 3669 def test_naughty_anomaly_access(self): 3670 class MyFunction(Function): 3671 @staticmethod 3672 def forward(ctx, x): 3673 return x 3674 3675 @staticmethod 3676 def backward(ctx, g): 3677 return g 3678 3679 x = torch.zeros(1, requires_grad=True) 3680 y = MyFunction.apply(x) 3681 y.backward() 3682 y.grad_fn.metadata 3683 g = y.grad_fn 3684 del y 3685 g.metadata # this currently fails, but shouldn't 3686 3687 def test_naughty_autograd_function_stashing_ctx(self): 3688 saved_ctx = [] 3689 3690 class Id(Function): 3691 @staticmethod 3692 def forward(ctx, x): 3693 ctx.save_for_backward(x) 3694 return x 3695 3696 @staticmethod 3697 def backward(ctx, grad_x): 3698 saved_ctx.append(ctx) 3699 return ctx.saved_tensors 3700 3701 p = torch.zeros(1, requires_grad=True) 3702 loss = Id.apply(p) 3703 loss.backward(retain_graph=True) 3704 del loss 3705 # At this point in time, it complains that the graph has been freed 3706 # (which indeed true, although a somewhat indirect way of stating the 3707 # problem). 3708 self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors) 3709 3710 def test_custom_autograd_repeated_grad_grad(self): 3711 # This test failed the equality check in PR #22983; it's an interesting 3712 # and different test case worth enshrining. mult1 is not testing 3713 # anything that interesting, but mult2 is the interesting case. 3714 3715 def mult1(x): 3716 return x.prod(dim=-1).prod(dim=-1) 3717 3718 class Mult(torch.autograd.Function): 3719 @staticmethod 3720 def forward(ctx, x): 3721 y = mult1(x) 3722 ctx.save_for_backward(x, y) 3723 return y 3724 3725 @staticmethod 3726 def backward(ctx, grad_output): 3727 x, y = ctx.saved_tensors 3728 return (grad_output * y)[:, None, None] / x 3729 3730 mult2 = Mult.apply 3731 3732 def check_gradgrad_repeated(x, y): 3733 (gy,) = torch.autograd.grad(y[0], x, create_graph=True) 3734 (ggy_1,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True) 3735 (gy,) = torch.autograd.grad(y[0], x, create_graph=True) 3736 (ggy_2,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True) 3737 self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1]) 3738 3739 x = torch.ones(2, 4, 4).requires_grad_() 3740 check_gradgrad_repeated(x, mult1(x)) 3741 check_gradgrad_repeated(x, mult2(x)) 3742 3743 def test_custom_autograd_no_early_free(self): 3744 # This test failed complaining that buffers had already been freed 3745 # prior to #22983. Also pretty interesting test case. 
3746 class Double(torch.autograd.Function): 3747 @staticmethod 3748 def forward(ctx, x): 3749 y = x**2 3750 ctx.save_for_backward(x, y) 3751 return y 3752 3753 @staticmethod 3754 def backward(ctx, grad_output): 3755 x, _ = ctx.saved_tensors 3756 return grad_output * 2 * x 3757 3758 # this is equivalent, but uses the output of .forward() in .backward() 3759 class Double2(Double): 3760 @staticmethod 3761 def backward(ctx, grad_output): 3762 x, y = ctx.saved_tensors 3763 return grad_output * 2 * y / x 3764 3765 double = Double.apply 3766 double2 = Double2.apply 3767 3768 x = torch.tensor(2).double().requires_grad_() 3769 3770 self.assertTrue(gradcheck(double, x)) 3771 self.assertTrue(gradgradcheck(double, x)) 3772 self.assertTrue(gradcheck(double2, x)) 3773 self.assertTrue(gradgradcheck(double2, x)) 3774 3775 y = double(x) 3776 torch.autograd.grad(y, x, create_graph=True) 3777 torch.autograd.grad(y, x) 3778 3779 y = double2(x) 3780 torch.autograd.grad(y, x, create_graph=True) 3781 torch.autograd.grad(y, x) # should not error! 3782 3783 def test_detach(self): 3784 x = torch.randn(10, 10, requires_grad=True) 3785 y = x + 2 3786 y = y.detach() 3787 z = y * 4 + 2 3788 self.assertFalse(y.requires_grad) 3789 self.assertFalse(z.requires_grad) 3790 3791 x = torch.randn(10, 10, requires_grad=True) 3792 y = x * 2 3793 y = y.detach() 3794 self.assertFalse(y.requires_grad) 3795 self.assertIsNone(y.grad_fn) 3796 z = x + y 3797 z.sum().backward() 3798 # This is an incorrect gradient, but we assume that's what the user 3799 # wanted. detach() is an advanced option. 3800 self.assertEqual(x.grad, torch.ones(10, 10)) 3801 3802 # in-place detach 3803 x = torch.randn(10, 10, requires_grad=True) 3804 y = torch.randn(10, 10, requires_grad=True) 3805 a = x * 2 3806 (y + a).sum().backward(retain_graph=True) 3807 a.detach_() 3808 self.assertFalse(a.requires_grad) 3809 (y + a).sum().backward() # this won't backprop to x 3810 self.assertEqual(x.grad, torch.ones(10, 10) * 2) 3811 self.assertEqual(y.grad, torch.ones(10, 10) * 2) 3812 3813 # in-place detach on a view raises an exception 3814 view = x.narrow(0, 1, 4) 3815 self.assertRaisesRegex(RuntimeError, "view", lambda: view.detach_()) 3816 3817 def test_detach_base(self): 3818 "detaching base does not detach view" 3819 x = torch.randn(10, 10, requires_grad=True) 3820 view = x.narrow(0, 1, 4) 3821 x.detach_() 3822 self.assertFalse(x.requires_grad) 3823 self.assertTrue(view.requires_grad) 3824 self.assertIsNotNone(view.grad_fn) 3825 self.assertIs(view._base, x) 3826 3827 def test_detach_then_inplace_raises_in_autograd(self): 3828 x = torch.randn([], requires_grad=True) 3829 orig_x = x.detach().clone() 3830 3831 y = x**2 # saves x 3832 z = x.detach() 3833 z.zero_() 3834 with self.assertRaisesRegex(RuntimeError, "has been modified by an inplace"): 3835 y.backward() 3836 3837 def _test_type_conversion_backward(self, t): 3838 fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True) 3839 fvar.double().sum().backward() 3840 self.assertEqual(fvar.grad, torch.ones_like(fvar)) 3841 self.assertEqual(type(fvar.grad), type(fvar)) 3842 dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True) 3843 dvar.float().sum().backward() 3844 self.assertEqual(dvar.grad, torch.ones_like(dvar)) 3845 self.assertEqual(type(dvar.grad), type(dvar)) 3846 3847 def test_type_conversions(self): 3848 x = torch.randn(5, 5) 3849 self.assertIsInstance(x.float(), torch.FloatTensor) 3850 self.assertIsInstance(x.int(), torch.IntTensor) 3851 if torch.cuda.is_available(): 3852 
self.assertIsInstance(x.float().cuda(), torch.cuda.FloatTensor) 3853 self.assertIsInstance(x.int().cuda(), torch.cuda.IntTensor) 3854 self.assertIsInstance(x.int().cuda().cpu(), torch.IntTensor) 3855 if torch.cuda.device_count() >= 2: 3856 x2 = x.float().cuda(1) 3857 self.assertIsInstance(x2, torch.cuda.FloatTensor) 3858 self.assertIs(x2.get_device(), 1) 3859 x2 = x.float().cuda() 3860 self.assertIsInstance(x2, torch.cuda.FloatTensor) 3861 self.assertIs(x2.get_device(), 0) 3862 x2 = x2.cuda(1) 3863 self.assertIsInstance(x2, torch.cuda.FloatTensor) 3864 self.assertIs(x2.get_device(), 1) 3865 y = Variable(torch.randn(5).cuda(1), requires_grad=True) 3866 y.cpu().sum().backward() 3867 self.assertIs(y.grad.get_device(), 1) 3868 self.assertIs(y.long().get_device(), 1) 3869 3870 for t in [ 3871 torch.DoubleTensor, 3872 torch.FloatTensor, 3873 torch.IntTensor, 3874 torch.ByteTensor, 3875 ]: 3876 for y_var in (True, False): 3877 y = torch.randint(5, (5, 5), dtype=t.dtype) 3878 y = Variable(y) if y_var else y 3879 self.assertIsInstance(x.type(t), t) 3880 self.assertIsInstance(x.type_as(y), t) 3881 # TODO: t.dtype should work 3882 t_dtype = t().dtype 3883 self.assertIsInstance(x.type(t_dtype), t) 3884 self.assertIs(t_dtype, x.type(t_dtype).dtype) 3885 self.assertEqual(y.data_ptr(), y.type(t).data_ptr()) 3886 if torch.cuda.is_available(): 3887 for x_cuda in (True, False): 3888 for y_cuda in (True, False): 3889 x_c = x.cuda() if x_cuda else x 3890 y_c = y.cuda() if y_cuda else y 3891 _, y_type = y_c.type().rsplit(".", 1) 3892 y_typestr = ("torch.cuda." if y_cuda else "torch.") + y_type 3893 self.assertEqual(y_c.type(), x_c.type(y_typestr).type()) 3894 self.assertIs(y_c.dtype, x_c.type(y_c.dtype).dtype) 3895 self.assertEqual( 3896 y_c.data_ptr(), 3897 y_c.cuda().data_ptr() if y_cuda else y_c.data_ptr(), 3898 ) 3899 3900 self._test_type_conversion_backward(lambda x: x) 3901 if torch.cuda.is_available(): 3902 self._test_type_conversion_backward(lambda x: x.cuda()) 3903 if torch.cuda.device_count() >= 2: 3904 # one of these has to be the non-default device 3905 self._test_type_conversion_backward(lambda x: x.cuda(0)) 3906 self._test_type_conversion_backward(lambda x: x.cuda(1)) 3907 3908 def test_isolated_node(self): 3909 x = torch.randn(5, 5, requires_grad=True) 3910 y = torch.randn(5, 5, requires_grad=True) 3911 3912 a = x + y 3913 b = torch.max(a, 1, True)[1].repeat(1, 5).double() 3914 o = (b + a).sum() 3915 o.backward() 3916 3917 def test_shape(self): 3918 x = torch.randn(3, 4) 3919 self.assertEqual(2, len(x.shape)) 3920 self.assertEqual(x.shape[0], 3) 3921 self.assertEqual(x.shape[1], 4) 3922 3923 def test_numpy_requires_grad(self): 3924 x = torch.randn(2, 2, requires_grad=True) 3925 err_msg_outputs = r"Can't call numpy\(\) on Tensor that requires grad. Use tensor.detach\(\).numpy\(\) instead." 3926 with self.assertRaisesRegex(RuntimeError, err_msg_outputs): 3927 x.numpy() 3928 3929 with torch.no_grad(): 3930 x.numpy() 3931 3932 x = torch.randn(2, 2) 3933 x.numpy() 3934 3935 with torch.no_grad(): 3936 x.numpy() 3937 3938 def test_return_leaf(self): 3939 class Identity(Function): 3940 @staticmethod 3941 def forward(ctx, a, b): 3942 return a, a + b 3943 3944 @staticmethod 3945 def backward(ctx, grad_a, grad_b): 3946 return grad_a + grad_b, grad_b 3947 3948 hook_called = [False] 3949 x = torch.randn(5, 5, requires_grad=True) 3950 y = torch.randn(5, 5, requires_grad=True) 3951 3952 q, p = Identity.apply(x, y) 3953 3954 # Make sure hooks only receive grad from usage of q, not x. 
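        # (q is an output of Identity.apply that shares x's data but is a
        # distinct node in the graph, so the gradient x accumulates from its
        # direct use in `q + p + x` must not show up in q's hook; the hook
        # should only ever see ones(5, 5).)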
3955 def hook(grad): 3956 hook_called[0] = True 3957 self.assertEqual(grad, torch.ones(5, 5)) 3958 3959 q.register_hook(hook) 3960 (q + p + x).sum().backward() 3961 self.assertEqual(x.grad, torch.ones(5, 5) * 3) 3962 self.assertEqual(y.grad, torch.ones(5, 5)) 3963 self.assertTrue(hook_called[0]) 3964 3965 def test_return_leaf_inplace(self): 3966 class Inplace(InplaceFunction): 3967 @staticmethod 3968 def forward(ctx, a, b): 3969 ctx.mark_dirty(a) 3970 return a.add_(b), b + 2 3971 3972 @staticmethod 3973 def backward(ctx, grad_a, grad_b): 3974 return grad_a, grad_a + grad_b 3975 3976 x = torch.randn(5, 5) 3977 y = torch.randn(5, 5, requires_grad=True) 3978 3979 q, p = Inplace.apply(x, y) 3980 self.assertIs(q, x) 3981 self.assertIs(q.grad_fn.__class__, Inplace._backward_cls) 3982 self.assertTrue(q.requires_grad) 3983 q.sum().backward() 3984 self.assertEqual(y.grad, torch.ones(5, 5)) 3985 3986 def test_leaf_assignment(self): 3987 x = torch.randn(5, 5) 3988 y = torch.randn(5, requires_grad=True) 3989 z = torch.randn(5, requires_grad=True) 3990 3991 x[0] = y 3992 x[1] = 2 * z 3993 self.assertTrue(x.requires_grad) 3994 self.assertIsNot(x.grad_fn, None) 3995 x.sum().backward() 3996 self.assertEqual(y.grad, torch.ones(5)) 3997 self.assertEqual(z.grad, torch.ones(5) * 2) 3998 3999 def test_no_grad_assignment(self): 4000 x = torch.randn(5, 5, requires_grad=True) 4001 y = torch.randn(5) 4002 with torch.no_grad(): 4003 x[0] = y 4004 4005 self.assertTrue(x.requires_grad) 4006 self.assertIsNone(x.grad_fn) 4007 4008 def test_no_grad_modifies_version(self): 4009 x = torch.randn(5, requires_grad=True) 4010 y = torch.randn(5, requires_grad=True) 4011 z = (x * y).sum() 4012 with torch.no_grad(): 4013 x *= 2 4014 self.assertRaisesRegex( 4015 RuntimeError, "modified by an inplace operation", lambda: z.backward() 4016 ) 4017 4018 def test_increment_version(self): 4019 a = torch.rand(5, requires_grad=True) 4020 v = a._version 4021 torch.autograd.graph.increment_version(a) 4022 self.assertEqual(a._version, v + 1) 4023 4024 a = torch.zeros(5, dtype=torch.int) 4025 v = a._version 4026 torch.autograd.graph.increment_version(a) 4027 self.assertEqual(a._version, v + 1) 4028 4029 with torch.inference_mode(): 4030 a = torch.rand(5, requires_grad=True) 4031 # does not error 4032 torch.autograd.graph.increment_version(a) 4033 4034 # does not error 4035 torch.autograd.graph.increment_version(a) 4036 4037 def test_no_grad_input(self): 4038 class MyFunction(Function): 4039 @staticmethod 4040 def forward(self, x): 4041 return x 4042 4043 @staticmethod 4044 def backward(self, grad_output): 4045 return grad_output 4046 4047 x = torch.randn(5, requires_grad=True) 4048 with torch.no_grad(): 4049 y = MyFunction.apply(x) 4050 4051 self.assertTrue(x.requires_grad) 4052 self.assertIsNone(y.grad_fn) 4053 4054 def test_backward_copy(self): 4055 # This tests checks backward engine for a very subtle bug that appreared 4056 # in one of the initial versions of autograd. Gradients tensors were 4057 # simply stored in lists while the function waited for all its gradients 4058 # to be computed. However, sometimes an output was used multiple times, 4059 # so the gradients needed to be summed. Engine used to keep a need_copy 4060 # set of tensors that will need a clone upon next addition and removed 4061 # them from the set as soon as the clone was performed. However, this 4062 # could lead to incorrect results if the same gradient tensor was 4063 # buffered in three places in the graph: 4064 # 1. 
When accumulating gradients in one of these places it was cloned 4065 # and removed from need_copy set. 4066 # 2. When accumulating in second place, it wasn't in the need_copy set, 4067 # so the gradients were simply accumulated in-place (which already 4068 # modified the grad in 3rd place) 4069 # 3. When accumulating in the third place, it wasn't in the need_copy set 4070 # as well, so the incoming gradient was summed in-place, yielding 4071 # incorrect results in all functions, except the first one. 4072 x = torch.ones(5, 5, requires_grad=True) 4073 y = torch.ones(5, 5, requires_grad=True) 4074 # Simulate that we're in the middle of the graph 4075 a = x + 2 4076 b = y + 2 4077 c = x + 2 4078 # This op will just return grad_output two times in backward 4079 add1 = a + b 4080 add2 = add1 + c 4081 # Simulate a long branch, so grad_output will get buffered. 4082 for _ in range(4): 4083 a = a * 2 4084 b = b * 2 4085 c = c * 2 4086 branch = a + b + c 4087 out = add2 + branch 4088 # expected gradients are: 4089 # for x: 34 (16 from final a, 16 from final c, 2 from add2) 4090 # for y: 17 (16 from final b, 1 from add2) 4091 grad_output = torch.ones(5, 5) 4092 out.backward(grad_output) 4093 self.assertEqual(x.grad, torch.ones(5, 5) * 34) 4094 self.assertEqual(y.grad, torch.ones(5, 5) * 17) 4095 4096 def test_save_none_for_backward(self): 4097 test_case = self 4098 4099 class MyFn(Function): 4100 @staticmethod 4101 def forward(ctx, input): 4102 ctx.save_for_backward(None, input, None) 4103 return input * input 4104 4105 @staticmethod 4106 def backward(ctx, grad_output): 4107 n1, input, n2 = ctx.saved_tensors 4108 test_case.assertIsNone(n1) 4109 test_case.assertIsNone(n2) 4110 return 2 * input * grad_output 4111 4112 x = torch.randn(5, 5, requires_grad=True) 4113 y = MyFn.apply(x) 4114 y.sum().backward() 4115 self.assertEqual(x.grad, 2 * x) 4116 4117 def test_too_many_grads(self): 4118 class MyFn(Function): 4119 @staticmethod 4120 def forward(ctx, input): 4121 return input 4122 4123 @staticmethod 4124 def backward(ctx, grad_output): 4125 return grad_output, None, None 4126 4127 x = torch.randn(5, 5, requires_grad=True) 4128 y = MyFn.apply(x) 4129 y.sum().backward() 4130 self.assertEqual(x.grad, torch.ones_like(x)) 4131 4132 def test_pickle(self): 4133 x = torch.randn(10, 10, requires_grad=True) 4134 y = torch.randn(10, 10, requires_grad=False) 4135 4136 def assert_strict_equal(var1, var2): 4137 self.assertEqual(var1, var2) 4138 self.assertEqual(var1.requires_grad, var2.requires_grad) 4139 4140 serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)] 4141 for dump in serialized: 4142 xc, yc = pickle.loads(dump) 4143 assert_strict_equal(xc, x) 4144 assert_strict_equal(yc, y) 4145 4146 @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") 4147 def test_dep_nograd(self): 4148 class F1(Function): 4149 @staticmethod 4150 def forward(ctx, input): 4151 out = torch.randn(input.size()) 4152 ctx.mark_non_differentiable(out) 4153 return input, out 4154 4155 @staticmethod 4156 def backward(ctx, grad_output, ignored): 4157 return grad_output 4158 4159 class F2(Function): 4160 @staticmethod 4161 def forward(ctx, input, ignored): 4162 return input 4163 4164 @staticmethod 4165 def backward(ctx, grad_output): 4166 return grad_output, None 4167 4168 x = torch.randn(5, requires_grad=True) 4169 a, b = F1.apply(x) 4170 b = b + 1 # separate F1 from F2 by another op 4171 self.assertTrue(a.requires_grad) 4172 self.assertFalse(b.requires_grad) 4173 c = F2.apply(a, b) 4174 
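        # b was marked non-differentiable by F1, so `b + 1` does not require
        # grad and F2's backward returns None for it; gradients reach x only
        # through a.  c itself still requires grad:
        self.assertTrue(c.requires_grad)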
c.backward(torch.ones(c.size())) 4175 self.assertEqual(x.grad, torch.ones(x.size())) 4176 4177 def test_set_grad_enabled(self): 4178 x = torch.tensor([1.0], requires_grad=True) 4179 with torch.set_grad_enabled(False): 4180 y = x * 2 4181 self.assertFalse(y.requires_grad) 4182 with torch.set_grad_enabled(True): 4183 y = x * 2 4184 self.assertTrue(y.requires_grad) 4185 with torch.set_grad_enabled(False): 4186 torch.set_grad_enabled(True) 4187 y = x * 2 4188 self.assertTrue(y.requires_grad) 4189 4190 def test_set_grad_enabled_wraps(self): 4191 for decorator in [True, False]: 4192 with torch.enable_grad(): 4193 self.assertTrue(torch.is_grad_enabled()) 4194 4195 if decorator: 4196 # This should not mutate the global grad mode! 4197 @torch.set_grad_enabled(False) 4198 def inner_func(x): 4199 return x.sin() 4200 4201 else: 4202 4203 def inner_func(x): 4204 return x.sin() 4205 4206 # This is non-idiomatic usage! 4207 # More idiomatic usage: torch.set_grad_enabled(False)(inner_func) 4208 obj = torch.set_grad_enabled(False) 4209 self.assertTrue(not torch.is_grad_enabled()) 4210 4211 # this will consume the set_grad_enabled global mutation! 4212 inner_func = obj(inner_func) 4213 self.assertTrue(torch.is_grad_enabled()) 4214 4215 self.assertTrue(torch.is_grad_enabled()) 4216 4217 x = torch.zeros(1, requires_grad=True) 4218 self.assertTrue(not inner_func(x).requires_grad) 4219 4220 def test_simple_reentrant(self): 4221 y_data = torch.randn(2, 2) 4222 4223 class Reenter(Function): 4224 @staticmethod 4225 def forward(ctx, x): 4226 with torch.enable_grad(): 4227 ctx.x = Variable(x, requires_grad=True) 4228 ctx.y = Variable(y_data, requires_grad=True) 4229 ctx.output_var = ctx.x * ctx.y 4230 return ctx.output_var.detach() 4231 4232 @staticmethod 4233 def backward(ctx, grad_output): 4234 with torch.enable_grad(): 4235 ctx.output_var.sum().backward() 4236 return ctx.x.grad * grad_output 4237 4238 # Reentrant starts on CPU thread, finishs on GPU thread 4239 x = torch.randn(2, 2, requires_grad=True) 4240 out = Reenter.apply(x) 4241 out.sum().backward() 4242 self.assertEqual(x.grad, y_data) 4243 4244 def test_reentrant_child_error(self): 4245 # Parent graph. 4246 a = torch.rand(3, 3, requires_grad=True) 4247 c = a * a 4248 4249 # Reentrant child graph. 4250 b = torch.rand(3, 3, requires_grad=True) 4251 e = b * b 4252 f = TestAutograd.SimulateBackwardError.apply(e) 4253 reentrant_root = f.sum() 4254 4255 class ReentrantFunc(Function): 4256 @staticmethod 4257 def forward(ctx, inp): 4258 return inp.clone() 4259 4260 @staticmethod 4261 def backward(ctx, grad): 4262 # Reentrant backward in child will throw an error. 
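                # Calling backward() here re-enters the autograd engine from
                # within a backward call; the "Simulate error" raised while
                # backpropagating the child graph must propagate out of the
                # parent d.sum().backward() below.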
4263 reentrant_root.backward() 4264 return grad 4265 4266 d = ReentrantFunc.apply(c) 4267 with self.assertRaisesRegex(Exception, "Simulate error"): 4268 d.sum().backward() 4269 4270 def test_var_mean_differentiable(self): 4271 dim = [2, 4] 4272 keepdim = False 4273 input1 = torch.randn(3, 4, 5, 6, 2, 3, requires_grad=True) 4274 input2 = deepcopy(input1) 4275 var1, mean1 = torch.var_mean(input1, dim=dim, keepdim=keepdim) 4276 var2 = input2.var(dim=dim, keepdim=keepdim) 4277 mean2 = input2.mean(dim=dim, keepdim=keepdim) 4278 grad = torch.randn(3, 4, 6, 3, requires_grad=True) 4279 4280 r1 = var1 * var1 * mean1 * mean1 4281 r2 = var2 * var2 * mean2 * mean2 4282 self.assertEqual(r1, r2, rtol=0.01, atol=0.0) 4283 4284 torch.autograd.backward(r1, grad) 4285 torch.autograd.backward(r2, grad) 4286 self.assertEqual(input1.grad, input2.grad, rtol=0.01, atol=0.0) 4287 4288 @skipIfNoLapack 4289 def test_lobpcg(self): 4290 def func(k, A, largest=True, B=None): 4291 X_shape = list(A.shape) 4292 X_shape[-1] = k 4293 X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device) 4294 if A.dim() > 2: 4295 X = X.expand(X_shape) 4296 4297 D, U = torch.lobpcg(A=A, k=k, B=B, X=X, largest=largest) 4298 4299 # LOBPCG uses a random initial eigenspace approximation 4300 # if parameter `X` is not provided. 4301 # This may cause a non-deterministic behavior 4302 # when it comes to the sign of an eigenvector 4303 # (note if v is an eigenvector, so is -v), 4304 # hence we eliminate this non-determinism 4305 # by making sure that each column of U 4306 # gets multiplied by the sign of its max (in absolute value) element. 4307 # Also, gradcheck changes the content of the input by +/- eps (default to 1e-06) 4308 # to compute the numerical gradient which can also cause the signs to flip. 4309 _, idx = U.abs().max(-2, keepdim=True) 4310 sign = U.gather(-2, idx).sign() 4311 U = U * sign 4312 return D, U 4313 4314 # TODO: review if this can be ported to OpInfos or moved to test_linalg.py 4315 def run_symeig_test(k, sizes, largest=True): 4316 A = torch.rand(*sizes).double() 4317 A = (A @ A.mT) / 10 4318 A.requires_grad_(True) 4319 4320 gradcheck(lambda A: func(k, A, largest), A, check_batched_grad=False) 4321 4322 # Custom gradient vectors for better stability due to some 4323 # non-determinism in the lobpcg's forward. 4324 # Note it is not required if symeig is in forward instead (tested). 
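            # (D_grad matches D's shape (*batch, k) and U_grad matches U's
            # shape (*batch, m, k), where m = A.size(-2).)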
4325 D_grad = torch.rand(*A.shape[:-2], k) / 100 4326 U_grad = torch.rand(*A.shape[:-1], k) / 100 4327 gradgradcheck( 4328 lambda A: func(k, A, largest), 4329 A, 4330 [D_grad, U_grad], 4331 atol=1e-4, 4332 check_batched_grad=False, 4333 ) 4334 4335 # check whether A.grad is symmetric 4336 A = A.detach().requires_grad_(True) 4337 D, U = func(k, A, largest) 4338 (D.sum() + U.sum()).backward() 4339 self.assertEqual(A.grad, A.grad.mT) 4340 4341 for largest in [True, False]: 4342 run_symeig_test(1, (6, 6), largest=largest) 4343 run_symeig_test(1, (2, 6, 6), largest=largest) 4344 run_symeig_test(1, (2, 2, 6, 6), largest=largest) 4345 run_symeig_test(2, (6, 6), largest=largest) 4346 run_symeig_test(2, (2, 6, 6), largest=largest) 4347 run_symeig_test(2, (2, 2, 6, 6), largest=largest) 4348 run_symeig_test(3, (9, 9), largest=largest) 4349 run_symeig_test(3, (2, 9, 9), largest=largest) 4350 run_symeig_test(3, (2, 2, 9, 9), largest=largest) 4351 4352 def test_variable_traverse(self): 4353 def get_out_and_unrefed_cycle(): 4354 inp = torch.randn(10, requires_grad=True) 4355 tmp = inp.view(10, 1) 4356 out = tmp.view(10) 4357 4358 # Create a reference cycle that contains an 4359 # intermediary Variable in the graph 4360 my_list = [] 4361 my_list.append(tmp) 4362 my_list.append(my_list) 4363 4364 return out 4365 4366 out = get_out_and_unrefed_cycle() 4367 gc.collect() 4368 # This will segfault if things have been erroneously released 4369 out.backward(torch.randn(out.size())) 4370 4371 # TODO: review porting these to OpInfo tests 4372 def test_pow_zero_tensor_gradient(self): 4373 def run_test(input_size, exponent): 4374 input = torch.zeros(*input_size, requires_grad=True) 4375 input.pow(exponent).sum().backward() 4376 self.assertEqual(input.grad.abs().sum(), 0) 4377 4378 run_test((10,), torch.zeros(10)) 4379 run_test((10, 10), torch.zeros(10, 10)) 4380 run_test((10,), 0) 4381 4382 def test_current_graph_task_id(self): 4383 id = [-1] 4384 4385 def hook(_): 4386 id[0] = torch._C._current_graph_task_id() 4387 4388 t = torch.tensor(1.0, requires_grad=True).clone() 4389 t.register_hook(hook) 4390 4391 t.backward(retain_graph=True) 4392 base = id[0] 4393 t.backward(retain_graph=True) 4394 self.assertEqual(id[0] - base, 1) 4395 t.backward(retain_graph=True) 4396 self.assertEqual(id[0] - base, 2) 4397 4398 self.assertEqual(torch._C._current_graph_task_id(), -1) 4399 4400 def test_current_graph_task_execution_order(self): 4401 predicted = [None] 4402 4403 def hook(_): 4404 predicted[0] = torch._C._current_graph_task_execution_order() 4405 4406 def names(nodes): 4407 return ", ".join([node.name().split(" ")[-1] for node in nodes]) + "\n" 4408 4409 def grad_fns(*tensors): 4410 # or grad accumulator 4411 out = [] 4412 for t in tensors: 4413 if t.requires_grad and t.grad_fn is None: 4414 out.append(t.clone().grad_fn.next_functions[0][0]) 4415 else: 4416 out.append(t.grad_fn) 4417 return out 4418 4419 actual = [] 4420 4421 def register_logging_hooks(*tensors): 4422 # register hooks that log the order in which they are called 4423 def get_hook(i): 4424 def hook(t_): 4425 actual.append(tensors[i]) 4426 4427 return hook 4428 4429 for i, t in enumerate(tensors): 4430 t.register_hook(get_hook(i)) 4431 4432 # Basic example: single path 4433 t = torch.tensor(1.0, requires_grad=True).clone().sin().exp() 4434 t.register_hook(hook) 4435 with torch.autograd.set_multithreading_enabled(False): 4436 t.backward() 4437 self.assertExpectedInline( 4438 names(predicted[0]), 4439 """\ 4440ExpBackward0, SinBackward0, CloneBackward0, 
torch::autograd::AccumulateGrad 4441""", 4442 ) 4443 4444 # We don't exactly follow sequence_nr order 4445 a = torch.tensor(1.0, requires_grad=True) 4446 b = torch.tensor(2.0, requires_grad=True) 4447 c = b.sin() 4448 d = a.cos() 4449 out = c * d 4450 register_logging_hooks(a, b, c, d, out) 4451 out.register_hook(hook) 4452 with torch.autograd.set_multithreading_enabled(False): 4453 out.backward() 4454 self.assertEqual(predicted[0], grad_fns(*actual)) 4455 actual = [] 4456 4457 # Accumulate grad node has more than one input 4458 a = torch.tensor(1.0, requires_grad=True) 4459 b = a.sin() 4460 c = a.cos() 4461 out = b * c 4462 register_logging_hooks(a, b, c, out) 4463 out.register_hook(hook) 4464 with torch.autograd.set_multithreading_enabled(False): 4465 out.backward() 4466 self.assertEqual(predicted[0], grad_fns(*actual)) 4467 actual = [] 4468 4469 # Multiple roots are also OK 4470 a = torch.tensor(1.0, requires_grad=True) 4471 b = a * 2 4472 out = b.sin() 4473 out2 = b.cos() 4474 out3 = b.cos() 4475 register_logging_hooks(a, b, out, out2, out3) 4476 out3.register_hook(hook) 4477 with torch.autograd.set_multithreading_enabled(False): 4478 torch.autograd.grad((out, out3, out2), inputs=(a,)) 4479 self.assertExpectedInline( 4480 names(predicted[0]), 4481 """\ 4482CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad 4483""", 4484 ) 4485 # TODO: Uncomment after update to hooks behavior 4486 # self.assertEqual(predicted[0], grad_fns(*actual)) 4487 actual = [] 4488 4489 # Case where next node is nullptr 4490 a = torch.tensor(1.0, requires_grad=True) 4491 b = a * 2 4492 out = b.sin() 4493 register_logging_hooks(a, b, out) 4494 out.register_hook(hook) 4495 with torch.autograd.set_multithreading_enabled(False): 4496 out.backward() 4497 self.assertEqual(predicted[0], grad_fns(*actual)) 4498 actual = [] 4499 4500 # Case where two `inputs` on the same path 4501 a = torch.tensor(1.0, requires_grad=True) 4502 b = a * 2 4503 out = b.sin() 4504 register_logging_hooks(a, b, out) 4505 out.register_hook(hook) 4506 with torch.autograd.set_multithreading_enabled(False): 4507 torch.autograd.grad((out,), inputs=(a, b)) 4508 self.assertEqual( 4509 names(predicted[0]), 4510 """\ 4511SinBackward0, MulBackward0, torch::autograd::AccumulateGrad 4512""", 4513 ) 4514 # TODO: Uncomment after update to hooks behavior 4515 # self.assertEqual(predicted[0], grad_fns(*actual)) 4516 actual = [] 4517 4518 # Case where `inputs` specifies a subgraph 4519 a = torch.tensor(1.0, requires_grad=True) 4520 b = torch.tensor(1.0, requires_grad=True) 4521 c = a * b 4522 out = c.sin() 4523 register_logging_hooks(a, b, c, out) 4524 out.register_hook(hook) 4525 with torch.autograd.set_multithreading_enabled(False): 4526 torch.autograd.grad((out,), inputs=(a,)) 4527 self.assertEqual( 4528 names(predicted[0]), 4529 """\ 4530SinBackward0, MulBackward0, torch::autograd::AccumulateGrad 4531""", 4532 ) 4533 # TODO: Uncomment after update to hooks behavior 4534 # self.assertEqual(predicted[0], grad_fns(*actual)) 4535 actual = [] 4536 4537 # Errors when not called in a backward 4538 with self.assertRaisesRegex( 4539 RuntimeError, "should only be called during the backward pass" 4540 ): 4541 torch._C._current_graph_task_execution_order() 4542 4543 # Errors when context manager not enabled 4544 t = torch.tensor(1.0, requires_grad=True).clone().sin().exp() 4545 t.register_hook(hook) 4546 with self.assertRaisesRegex( 4547 RuntimeError, 4548 "expects the current backward to be executed with multithreading disabled", 
4549 ): 4550 t.backward() 4551 4552 def test_view_replay_enabled(self): 4553 def f(x): 4554 out = x.clone().view(-1) 4555 # mutate the view, triggering autograd view-replay logic 4556 out.add_(1) 4557 return out 4558 4559 x = torch.ones(2, 2, requires_grad=True) 4560 4561 # Test as a context manager 4562 with torch.autograd._force_original_view_tracking(False): 4563 out = f(x) 4564 self.assertTrue("AsStridedBackward" in str(out.grad_fn)) 4565 self.assertFalse(torch.autograd.is_view_replay_enabled()) 4566 self.assertFalse(torch.autograd.is_view_replay_enabled()) 4567 4568 with torch.autograd._force_original_view_tracking(True): 4569 out = f(x) 4570 self.assertTrue("ViewBackward" in str(out.grad_fn)) 4571 self.assertTrue(torch.autograd.is_view_replay_enabled()) 4572 out = f(x) 4573 self.assertTrue("AsStridedBackward" in str(out.grad_fn)) 4574 self.assertFalse(torch.autograd.is_view_replay_enabled()) 4575 4576 with torch.autograd._force_original_view_tracking(False): 4577 torch.autograd._force_original_view_tracking(True) 4578 out = f(x) 4579 self.assertTrue("ViewBackward" in str(out.grad_fn)) 4580 self.assertTrue(torch.autograd.is_view_replay_enabled()) 4581 self.assertFalse(torch.autograd.is_view_replay_enabled()) 4582 4583 # Test as a function 4584 torch.autograd._force_original_view_tracking(False) 4585 out = f(x) 4586 self.assertTrue("AsStridedBackward" in str(out.grad_fn)) 4587 self.assertFalse(torch.autograd.is_view_replay_enabled()) 4588 4589 torch.autograd._force_original_view_tracking(True) 4590 out = f(x) 4591 self.assertTrue("ViewBackward" in str(out.grad_fn)) 4592 self.assertTrue(torch.autograd.is_view_replay_enabled()) 4593 4594 def test_unsafe_set_version_counter(self): 4595 x = torch.ones(2, requires_grad=True).clone() 4596 x.add_(1) 4597 x.add_(2) 4598 self.assertEqual(2, x._version) 4599 with torch.autograd._unsafe_preserve_version_counter(x): 4600 x.mul_(2) 4601 x.mul_(3) 4602 # version counter doesn't change inside of the context manager 4603 self.assertEqual(2, x._version) 4604 4605 torch._C._autograd._unsafe_set_version_counter(x, 0) 4606 self.assertEqual(0, x._version) 4607 with self.assertRaisesRegex(RuntimeError, "Cannot set"): 4608 torch._C._autograd._unsafe_set_version_counter(x, -1) 4609 4610 def test_current_node(self): 4611 pr = [] 4612 4613 class MyMode(TorchDispatchMode): 4614 def __torch_dispatch__(self, func, types, args, kwargs=None): 4615 node = torch._C._current_autograd_node() 4616 # Don't use node.name() here as it is not consistent on windows 4617 node_name = node.__class__.__name__ if node else "None" 4618 pr.append(f"Running {func} from within {node_name}") 4619 return func(*args, **(kwargs or {})) 4620 4621 with MyMode(): 4622 pr.append("FW") 4623 a = torch.rand(10, requires_grad=True) 4624 b = a.mul(2).div(3).sum() 4625 pr.append("BW") 4626 b.backward() 4627 pr.append("Done") 4628 4629 self.assertExpectedInline( 4630 "\n".join(pr), 4631 """\ 4632FW 4633Running aten.rand.default from within None 4634Running aten.mul.Tensor from within None 4635Running aten.div.Tensor from within None 4636Running aten.sum.default from within None 4637BW 4638Running aten.ones_like.default from within None 4639Running aten.expand.default from within SumBackward0 4640Running aten.div.Tensor from within DivBackward0 4641Running aten.mul.Tensor from within MulBackward0 4642Running aten.detach.default from within AccumulateGrad 4643Running aten.detach.default from within AccumulateGrad 4644Done""", 4645 ) 4646 4647 def test_profiler(self): 4648 x = torch.randn(10, 10) 4649 
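        # Minimal sketch of the API exercised below, assuming only the public
        # profiler interface: ops run inside the profile(...) context are
        # recorded and exposed via .function_events, each carrying an op name
        # such as "aten::add".
        with profile(use_kineto=kineto_available()) as warmup_prof:
            torch.ones(2) + 1
        self.assertTrue(
            any("aten::add" in evt.name for evt in warmup_prof.function_events)
        )
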
4650 with profile(use_kineto=kineto_available()) as p: 4651 self.assertTrue(torch.autograd._profiler_enabled()) 4652 y = x * 2 + 4 4653 4654 self.assertFalse(torch.autograd._profiler_enabled()) 4655 4656 names = ["aten::mul", "aten::add"] 4657 found_indices = set() 4658 for evt in p.function_events: 4659 if evt.name in names: 4660 found_indices.add(names.index(evt.name)) 4661 self.assertEqual(len(found_indices), len(names)) 4662 4663 def test_profiler_seq_nr(self): 4664 with profile(use_kineto=kineto_available()) as p: 4665 x = torch.randn(10, 10, requires_grad=True) 4666 y = torch.randn(10, 10, requires_grad=True) 4667 z = x + y 4668 s = z.sum(dim=None) 4669 s.backward() 4670 print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1)) 4671 # expecting aten::add, aten::sum to have the sequence numbers, 4672 # expecting the corresponding backward nodes to have the same numbers 4673 # as the forward ops 4674 autograd_ops = { 4675 ("aten::add", "Add"): [], 4676 ("aten::sum", "Sum"): [], 4677 } 4678 accumulate_ops = [] 4679 found_empty = False 4680 for e in p.function_events: 4681 for (fwd_name, bwd_name), ops in autograd_ops.items(): 4682 if e.name == fwd_name or (bwd_name in e.name and "Backward" in e.name): 4683 ops.append(e) 4684 4685 if "AccumulateGrad" in e.name: 4686 accumulate_ops.append(e) 4687 4688 # check that nested ops (e.g. empty) don't have 4689 # sequence number 4690 if e.name == "aten::empty": 4691 self.assertEqual(e.sequence_nr, -1) 4692 found_empty = True 4693 4694 for idx, ((fwd_name, bwd_name), ops) in enumerate(autograd_ops.items()): 4695 self.assertEqual(len(ops), 3) 4696 self.assertEqual(ops[0].name, fwd_name) 4697 self.assertEqual( 4698 ops[1].name, 4699 f"autograd::engine::evaluate_function: {bwd_name}Backward{idx}", 4700 ) 4701 self.assertEqual(ops[2].name, f"{bwd_name}Backward{idx}") 4702 self.assertGreaterEqual(ops[0].sequence_nr, 0) 4703 self.assertEqual(ops[1].sequence_nr, ops[0].sequence_nr) 4704 self.assertEqual(ops[2].sequence_nr, ops[0].sequence_nr) 4705 self.assertEqual(ops[0].fwd_thread, 0) 4706 self.assertEqual(ops[1].fwd_thread, ops[0].thread) 4707 self.assertEqual(ops[2].fwd_thread, ops[0].thread) 4708 self.assertTrue(found_empty) 4709 4710 def test_profiler_unboxed_only(self): 4711 x = torch.rand(3, 4) 4712 4713 with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: 4714 x.resize_([3, 2]) 4715 4716 def test_profiler_propagation(self): 4717 def foo(x): 4718 with record_function("in_foo") as rf: 4719 return x * 2 4720 4721 x = torch.rand(3, 4) 4722 traced_foo = torch.jit.trace(foo, x) 4723 4724 def bar(x): 4725 with record_function("in_bar") as rf: 4726 # we expect that profiler will be able 4727 # propagate across fork 4728 fut = torch.jit._fork(traced_foo, x) 4729 y = torch.jit._wait(fut) 4730 # note: continuation (and rf's end) can 4731 # be executed in a different thread 4732 with record_function("in_bar_after_wait") as rf2: 4733 y = y * 2 4734 return y 4735 4736 traced_bar = torch.jit.trace(bar, x) 4737 4738 with profile(use_kineto=kineto_available()) as p: 4739 traced_bar(x) 4740 4741 found_foo = False 4742 found_bar = False 4743 found_bar_after_wait = False 4744 for info in p.function_events: 4745 if info.name == "in_foo": 4746 self.assertFalse(found_foo) 4747 found_foo = True 4748 elif info.name == "in_bar": 4749 self.assertFalse(found_bar) 4750 found_bar = True 4751 elif info.name == "in_bar_after_wait": 4752 self.assertFalse(found_bar_after_wait) 4753 found_bar_after_wait = True 4754 
self.assertTrue(found_foo) 4755 self.assertTrue(found_bar) 4756 self.assertTrue(found_bar_after_wait) 4757 4758 def test_record_function_callbacks(self): 4759 x = torch.randn(10, 10) 4760 with profile(use_kineto=kineto_available()) as p: 4761 with record_function("foo"): 4762 y = x * 2 + 4 4763 4764 function_events = p.function_events 4765 foo_event = next(event for event in function_events if "foo" in event.name) 4766 self.assertEqual(foo_event.count, 1) 4767 4768 def test_record_function_legacy(self): 4769 # Test the new _record_function ops work 4770 # Note: Remove once record_function uses these directly 4771 x = torch.randn(10, 10) 4772 with profile(use_kineto=kineto_available()) as p: 4773 handle = torch.ops.profiler._record_function_enter("bar", None) 4774 try: 4775 y = x * 2 + 4 4776 finally: 4777 torch.ops.profiler._record_function_exit(handle) 4778 4779 function_events = p.function_events 4780 foo_event = next(event for event in function_events if "bar" in event.name) 4781 self.assertEqual(foo_event.count, 1) 4782 4783 def test_profiler_aggregation_fake(self): 4784 events = EventList() 4785 id = [0] 4786 4787 def get_id(): 4788 id[0] = id[0] + 1 4789 return id[0] 4790 4791 # [[thread_id, [(start, end, id), ....]], ...] 4792 # Using list instead of a dict so order is guaranteed for any Python 4793 # version 4794 threads = [ 4795 [1, [(0, 1, get_id()), (1, 2, get_id())]], 4796 [0, [(0, 2, get_id()), (1, 2, get_id()), (1, 3, get_id())]], 4797 ] 4798 for thread, ranges in threads: 4799 for range in ranges: 4800 assert len(range) == 3 4801 events.append( 4802 FunctionEvent( 4803 id=range[2], 4804 node_id=0, 4805 name="", 4806 thread=thread, 4807 start_us=range[0], 4808 end_us=range[1], 4809 ) 4810 ) 4811 4812 events._populate_cpu_children() 4813 4814 # Note that [1, 3] pushes out [0, 2] first. 
Then we record [1, 2] 4815 # as a child of [1, 3] 4816 res = [[], [], [], [], [4]] 4817 4818 def get_children_ids(event): 4819 return [child.id for child in event.cpu_children] 4820 4821 assert [get_children_ids(event) for event in events] == res 4822 4823 def test_profiler_aggregation_table(self): 4824 """ 4825 Test if the profiling result is aggregated for `str(prof)` 4826 4827 See: https://github.com/pytorch/pytorch/issues/37500 4828 """ 4829 4830 x = torch.randn(1024) 4831 with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof: 4832 torch.einsum("i->", x) 4833 4834 prof_str = str(prof) 4835 prof_table = prof.table() 4836 4837 self.assertEqual(prof_table, prof_str) 4838 4839 def test_profiler_function_event_avg(self): 4840 avg = FunctionEventAvg() 4841 avg.add( 4842 FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15) 4843 ) 4844 avg.add( 4845 FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30) 4846 ) 4847 avg.add(avg) 4848 self.assertEqual(avg.key, "foo") 4849 4850 # aggregate stats 4851 self.assertEqual(avg.count, 4) 4852 self.assertEqual(avg.cpu_time_total, 30) 4853 self.assertEqual(avg.self_cpu_time_total, 30) 4854 self.assertEqual(avg.device_time_total, 0) 4855 4856 # average stats 4857 self.assertEqual(avg.cpu_time, 7.5) 4858 self.assertEqual(avg.device_time_total, 0) 4859 4860 def test_profiler_shapes(self): 4861 print() 4862 layer1 = torch.nn.Linear(20, 30) 4863 layer2 = torch.nn.Linear(30, 40) 4864 input = torch.randn(128, 20) 4865 with profile(record_shapes=True, use_kineto=kineto_available()) as prof: 4866 layer2(layer1(input)) 4867 4868 print(prof.function_events) 4869 4870 linear_expected_shapes = [ 4871 [[128, 20], [30, 20], [30]], 4872 [[128, 30], [40, 30], [40]], 4873 ] 4874 4875 found_indices = set() 4876 for event in prof.function_events: 4877 if event.name == "aten::linear": 4878 self.assertTrue(event.input_shapes in linear_expected_shapes) 4879 found_indices.add(linear_expected_shapes.index(event.input_shapes)) 4880 self.assertEqual(len(found_indices), len(linear_expected_shapes)) 4881 4882 def test_profiler_aggregation_lstm(self): 4883 print() 4884 rnn = torch.nn.LSTM(10, 20, 2) 4885 total_time_s = 0 4886 with profile(record_shapes=True, use_kineto=kineto_available()) as prof: 4887 for i in range(20): 4888 input = torch.randn(5, 3, 10) 4889 h = torch.randn(2, 3, 20) 4890 c = torch.randn(2, 3, 20) 4891 start = time.time() 4892 rnn(input, (h, c)) 4893 end = time.time() 4894 total_time_s += end - start 4895 4896 print(prof.table(sort_by="self_cpu_time_total", row_limit=10, header="TEST")) 4897 print( 4898 prof.key_averages(group_by_input_shape=True).table( 4899 sort_by="self_cpu_time_total", row_limit=10 4900 ) 4901 ) 4902 print( 4903 prof.table( 4904 sort_by="self_cpu_time_total", 4905 row_limit=10, 4906 max_src_column_width=300, 4907 header="TEST", 4908 top_level_events_only=True, 4909 ) 4910 ) 4911 print( 4912 prof.key_averages(group_by_input_shape=True).table( 4913 sort_by="self_cpu_time_total", row_limit=10, top_level_events_only=True 4914 ) 4915 ) 4916 4917 total_time_us = ( 4918 total_time_s * 1000.0 * 1000.0 4919 ) # make it us which is profiler default 4920 print("Total time based on python measurements: ", _format_time(total_time_us)) 4921 print( 4922 f"CPU time measurement python side overhead: {(total_time_us / prof.self_cpu_time_total - 1.0) * 100.0:.2f}%" 4923 ) 4924 4925 if sys.platform != "win32": 4926 with tempfile.NamedTemporaryFile() as trace_file: 4927 
prof.export_chrome_trace(trace_file.name) 4928 4929 def test_record_function(self): 4930 x = torch.randn(10, 10) 4931 4932 def forward(x): 4933 with record_function("outer"): 4934 y = x * 2 + 4 4935 with record_function("inner"): 4936 y = y - 1 4937 y = y / 1 4938 4939 forward(x) 4940 4941 with profile(use_kineto=kineto_available()) as p: 4942 forward(x) 4943 4944 events = p.function_events 4945 important_events = [ 4946 "outer", 4947 "aten::mul", 4948 "aten::add", 4949 "inner", 4950 "aten::sub", 4951 "aten::div", 4952 ] 4953 idx = 0 4954 for info in events: 4955 if info.name == important_events[idx]: 4956 idx = idx + 1 4957 if idx == len(important_events): 4958 break 4959 self.assertEqual(idx, len(important_events)) 4960 4961 # We can also use record_function to decorate arbitrary function 4962 @record_function("my_func") 4963 def f(x, y): 4964 return x + y 4965 4966 with profile(use_kineto=kineto_available()) as p: 4967 f(1, 2) 4968 4969 self.assertTrue("my_func" in str(p)) 4970 4971 def test_record_function_multithreaded(self): 4972 rf = record_function("outer") 4973 rf.__enter__() 4974 with record_function("inner"): 4975 # test that exiting the record function after starting another one 4976 # doesn't throw. 4977 rf.__exit__(None, None, None) 4978 4979 with record_function("inner"): 4980 rf.__enter__() 4981 # test that exiting the record function after ending another one 4982 # doesn't throw. 4983 rf.__exit__(None, None, None) 4984 4985 def test_dir(self): 4986 x = torch.randn(10, 10) 4987 keys = dir(x) 4988 self.assertIn("shape", keys) 4989 4990 # real and imag are only implemented for complex tensors. 4991 y = torch.randn(10, 10, dtype=torch.cfloat) 4992 imag_key = "imag" 4993 self.assertRaises(RuntimeError, lambda: hasattr(x, imag_key)) 4994 self.assertTrue(hasattr(y, imag_key)) 4995 keys.remove(imag_key) 4996 4997 for key in keys: 4998 self.assertTrue(hasattr(x, key)) 4999 5000 def test_inplace_on_view_saved_output(self): 5001 # Test an in-place operation on a view in which the in-place op saves 5002 # its output. Previously, this created a reference cycle. 5003 dealloc = [0] 5004 5005 class IncrementOnDelete: 5006 def __del__(self): 5007 dealloc[0] += 1 5008 5009 def test(): 5010 root = torch.randn(3, 3, requires_grad=True) 5011 copy = root.clone() 5012 copy.grad_fn.register_hook(IncrementOnDelete()) 5013 view = copy.view(9) 5014 torch.nn.functional.relu(view, inplace=True) 5015 5016 test() 5017 self.assertEqual(dealloc[0], 1) 5018 5019 def test_inplace_on_view_leaf_errors(self): 5020 # Issue #21875: Fail faster (when we try to modify the view vs. in backward()) 5021 x = torch.zeros(1, requires_grad=True) 5022 y = x.view_as(x) 5023 with self.assertRaisesRegex( 5024 RuntimeError, 5025 "a view of a leaf Variable that " 5026 "requires grad is being used in " 5027 "an in-place operation.", 5028 ): 5029 y.add_(1) 5030 5031 def test_inplace_on_view_backward(self): 5032 # Issue #10532: Make sure that this does not raise RuntimeError. 
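        # A minimal sketch of the pattern under test: an in-place op applied to
        # a view of a non-leaf tensor is differentiable, and gradients flow
        # back through the base.
        base = torch.ones(4, requires_grad=True)
        reshaped = base.clone().view(2, 2)
        reshaped.relu_()
        reshaped.sum().backward()
        self.assertEqual(base.grad, torch.ones(4))
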
5033 net = nn.Sequential(nn.InstanceNorm2d(2), nn.ReLU(True)) 5034 5035 x = torch.tensor([[[[1.0, 1.0]]]], requires_grad=True) 5036 (g,) = torch.autograd.grad( 5037 net(x).pow(2), [x], grad_outputs=x.new_ones(x.shape), create_graph=True 5038 ) 5039 torch.autograd.grad(g.sum(), [x]) 5040 self.assertEqual(x, torch.tensor([[[[1.0, 1.0]]]])) 5041 5042 # https://discuss.pytorch.org/t/freeing-buffer-strange-behavior/31955/8 5043 inputs = torch.ones((1, 3, 256, 256), requires_grad=True) 5044 5045 tmp1 = (inputs + 1).view_as(inputs) 5046 tmp2 = torch.nn.functional.threshold(tmp1, 0.0, 0.0, True) 5047 prob_interpolated = torch.sigmoid(tmp2) 5048 5049 gradients = torch.autograd.grad( 5050 outputs=prob_interpolated, 5051 inputs=inputs, 5052 grad_outputs=torch.ones(prob_interpolated.size()), 5053 create_graph=True, 5054 retain_graph=True, 5055 )[0] 5056 5057 gradient_penalty = gradients.sum() 5058 gradient_penalty.backward() 5059 5060 fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0] 5061 self.assertEqual(fn.name(), "ThresholdBackwardBackward0") 5062 5063 def test_inplace_on_view_weak_grad_fn(self): 5064 # Issue 23502: Test that b's grad_fn is preserved. 5065 a = torch.arange(10.0, requires_grad=True) 5066 5067 b = a.narrow(0, 0, 2).clone().view(-1) 5068 b.relu_() 5069 5070 c = b.clone() 5071 del b 5072 gc.collect() 5073 5074 s = c.sum() 5075 s.backward() 5076 self.assertEqual(s, torch.tensor(1.0)) 5077 5078 # Issue #21875: Fail faster (when we try to modify the view vs. in backward()) 5079 a = torch.rand(10, requires_grad=True).narrow(0, 0, 10) 5080 with self.assertRaises(RuntimeError): 5081 b = a.relu_() 5082 5083 def test_out_variant_raises_when_inputs_require_grad(self): 5084 a = torch.randn(2, 2, requires_grad=True) 5085 b = torch.randn(2, 2, requires_grad=True) 5086 x = torch.zeros_like(a) 5087 5088 # out=... 
functions don't support automatic differentiation currently 5089 self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x)) 5090 5091 # the inputs can require grad if we're in no_grad() mode 5092 with torch.no_grad(): 5093 torch.mul(a, b, out=x) 5094 self.assertEqual(x, a * b) 5095 5096 a = torch.randn(2, 2) 5097 b = torch.randn(2, 2) 5098 x = torch.zeros(2, 2, requires_grad=True) 5099 # we should throw an exception if the output requires grad 5100 self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x)) 5101 5102 def test_anomaly_detect_nan(self): 5103 size = 10 5104 5105 class MyFunc(Function): 5106 @staticmethod 5107 def forward(ctx, inp1, inp2, fail_0th): 5108 ctx.fail_0th = fail_0th 5109 return inp1.sum(0, keepdim=True) 5110 5111 @staticmethod 5112 def backward(ctx, gO): 5113 gI = gO.clone().expand(size) 5114 gI[0] = 0 5115 gI[0] /= 0 # Generate a nan 5116 if ctx.fail_0th: 5117 return gI, None, None 5118 else: 5119 return None, gI, None 5120 5121 inp = torch.rand(size, requires_grad=True) 5122 out = MyFunc.apply(inp, inp, True) 5123 out.backward() # Should not fail 5124 5125 inp = torch.rand(size, requires_grad=True) 5126 out = MyFunc.apply(inp, inp, True) 5127 with self.assertRaisesRegex( 5128 RuntimeError, 5129 "Function 'MyFuncBackward' returned nan values in its 0th output.", 5130 ): 5131 with warnings.catch_warnings(record=True) as w: 5132 with detect_anomaly(): 5133 out.backward() 5134 self.assertIn("No forward pass information", str(w[0].message)) 5135 5136 inp = torch.rand(size, requires_grad=True) 5137 with self.assertRaisesRegex( 5138 RuntimeError, 5139 "Function 'MyFuncBackward' returned nan values in its 1th output.", 5140 ): 5141 with warnings.catch_warnings(record=True) as w: 5142 with detect_anomaly(): 5143 out = MyFunc.apply(inp, inp, False) 5144 out.backward() 5145 self.assertIn("MyFunc.apply", str(w[0].message)) 5146 5147 def test_calculate_shape_util(self): 5148 out = torch.randn(10, 5, requires_grad=True) 5149 grad = torch.randn(5, 10, requires_grad=True) 5150 out_shape, grad_shape = _calculate_shape(out, grad, False) 5151 5152 assert out_shape == torch.Size([10, 5]) 5153 assert grad_shape == torch.Size([5, 10]) 5154 5155 out = torch.nested.as_nested_tensor( 5156 [ 5157 torch.randn(10, 5, requires_grad=True), 5158 torch.randn(10, 5, requires_grad=True), 5159 torch.randn(10, 5, requires_grad=True), 5160 ] 5161 ) 5162 grad = torch.nested.as_nested_tensor( 5163 [ 5164 torch.randn(5, 10, requires_grad=True), 5165 torch.randn(5, 10, requires_grad=True), 5166 ] 5167 ) 5168 out_shape, grad_shape = _calculate_shape(out, grad, False) 5169 5170 assert torch.equal(out_shape, torch.tensor([[10, 5], [10, 5], [10, 5]])) 5171 assert torch.equal(grad_shape, torch.tensor([[5, 10], [5, 10]])) 5172 5173 def test_nested_anomaly_detect_nan(self): 5174 size = 10 5175 5176 class MyFunc(Function): 5177 @staticmethod 5178 def forward(ctx, inp1, fail_0th): 5179 ctx.fail_0th = fail_0th 5180 ctx.save_for_backward(inp1) 5181 return inp1.sum(0, keepdim=True) 5182 5183 @staticmethod 5184 def backward(ctx, gO): 5185 (inp,) = ctx.saved_tensors 5186 fail_0th = ctx.fail_0th 5187 g = gO.clone().expand(size) 5188 gI = MyFunc2.apply(g * inp, g + inp, fail_0th) 5189 return gI, None 5190 5191 class MyFunc2(Function): 5192 @staticmethod 5193 def forward(ctx, inp1, inp2, fail_0th): 5194 ctx.fail_0th = fail_0th 5195 return inp1 * 2.0 + inp2 5196 5197 @staticmethod 5198 def backward(ctx, gO): 5199 fail_0th = ctx.fail_0th 5200 g1 = gO.clone() 5201 g2 = gO.clone() 5202 
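                # The first element of both grads is zeroed; dividing one of
                # them by zero below injects a NaN into exactly one grad input.
                # Nested anomaly mode should then attribute the NaN to
                # MyFunc2Backward's 0th or 1th output, depending on fail_0th.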
g1[0] = 0 5203 g2[0] = 0 5204 # generate a nan 5205 if fail_0th: 5206 g1[0] /= 0 5207 else: 5208 g2[0] /= 0 5209 return g1, g2, None 5210 5211 inp = torch.rand(size, requires_grad=True) 5212 out = MyFunc.apply(inp, True) 5213 (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True) 5214 gsum = ginp.sum() 5215 gsum.backward() # should not fail 5216 5217 inp = torch.rand(size, requires_grad=True) 5218 out = MyFunc.apply(inp, True) 5219 (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True) 5220 gsum = ginp.sum() 5221 with warnings.catch_warnings(record=True) as w: 5222 with self.assertRaisesRegex( 5223 RuntimeError, 5224 "Function 'MyFunc2Backward' returned nan values in its 0th output.", 5225 ): 5226 with detect_anomaly(): 5227 gsum.backward() 5228 self.assertIn("No forward pass information", str(w[1].message)) 5229 5230 inp = torch.rand(size, requires_grad=True) 5231 with warnings.catch_warnings(record=True) as w: 5232 with self.assertRaisesRegex( 5233 RuntimeError, 5234 "Function 'MyFunc2Backward' returned nan values in its 1th output.", 5235 ): 5236 with detect_anomaly(): 5237 out = MyFunc.apply(inp, False) 5238 (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True) 5239 gsum = ginp.sum() 5240 gsum.backward() 5241 self.assertIn("MyFunc2.apply", str(w[1].message)) 5242 self.assertIn("MyFunc.apply", str(w[2].message)) 5243 5244 def test_anomaly_grad_warnings(self): 5245 # PyTorch won't throw warnings if there is an error 5246 # but we'd want to at least see them in stderr 5247 5248 class StdErrDiverter: 5249 def __enter__(self): 5250 self.stderr_orig = sys.stderr 5251 self.stderr_new = io.StringIO() 5252 sys.stderr = self.stderr_new 5253 return self 5254 5255 def __exit__(self, *args): 5256 self.captured = self.stderr_new.getvalue() 5257 sys.stderr = self.stderr_orig 5258 5259 # if the warnings don't throw, they will be handled as regular warnings 5260 with self.assertRaisesRegex( 5261 RuntimeError, 5262 "one of the variables needed for gradient computation has been " 5263 "modified by an inplace operation", 5264 ): 5265 with warnings.catch_warnings(record=True) as w: 5266 with detect_anomaly(): 5267 a = torch.randn(5, requires_grad=True) 5268 d1 = a + 1 5269 d2 = d1**2 5270 d1 += 1 5271 torch.autograd.grad(d2.sum(), a) 5272 5273 self.assertEqual(len(w), 2) 5274 self.assertIn("Anomaly Detection has been enabled", str(w[0].message)) 5275 self.assertIn("Error detected in PowBackward0", str(w[1].message)) 5276 5277 # if the warning throws, it will be printed to sys.stderr 5278 with self.assertRaisesRegex( 5279 RuntimeError, 5280 "one of the variables needed for gradient computation has been " 5281 "modified by an inplace operation", 5282 ): 5283 with warnings.catch_warnings(record=True) as w: 5284 with detect_anomaly(): 5285 warnings.simplefilter("error") 5286 with StdErrDiverter() as s: 5287 a = torch.randn(5, requires_grad=True) 5288 d1 = a + 1 5289 d2 = d1**2 5290 d1 += 1 5291 torch.autograd.grad(d2.sum(), a) 5292 5293 self.assertEqual(len(w), 1) 5294 self.assertIn("Anomaly Detection has been enabled", str(w[0].message)) 5295 self.assertIn("Error detected in PowBackward0", s.captured) 5296 5297 def test_anomaly_assign_parent_cleanup(self): 5298 # Test that python objects created are properly cleaned up when assign_parent is called 5299 5300 def get_ref(): 5301 # we use torch.exp here but any function that will construct a new node in its 5302 # backward call in grad mode will work 5303 x = torch.randn(2, 2, requires_grad=True) 5304 t = x.exp() 5305 5306 # ExpBackward 
            # calls mul, creating the MulBackward node when create_graph=True.
            # In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
            # MulBackward's anomaly metadata dict, creating the following reference chain:
            #
            # grad -> MulBackward -> PyObject -> ExpBackward
            #
            with detect_anomaly():
                grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)

            # We add a weak reference to a new Foo object, which we insert into ExpBackward's metadata dict
            #
            # (PyObject) -> ExpBackward -> dict -> *Foo*
            #         t ----^             WeakRef ---^
            #
            # We want to test that when grad goes out of scope at the end of this function, that PyObject is destroyed.
            # We can test this by checking that Foo is no longer kept alive once t is destroyed.
            class Foo:
                pass

            my_obj = Foo()
            meta_dict = t.grad_fn.metadata
            meta_dict[0] = my_obj
            ref = weakref.ref(my_obj)
            return t, ref

        t, ref = get_ref()
        self.assertIsNotNone(ref())
        del t
        self.assertIsNone(ref())

    def test_nested_anomaly_printstack_cleanup(self):
        # Test that the metadata dict PyObject is properly destroyed
        def get_ref():
            # This is similar to the construction in test_anomaly_assign_parent_cleanup:
            #
            # MyFunc2Backward -> PyObject -> MyFuncBackward -> dict -> Foo
            #                       out ---^              WeakRef ---^
            #
            # We want to check that Foo is still properly destroyed even when MyFunc2Backward's
            # AnomalyMetadata calls printstack, which does some Python object manipulation.
            #
            # You might be wondering why we still need test_anomaly_assign_parent_cleanup,
            # since if the PyObject were not destroyed, wouldn't this test detect that too?
            # The answer is that the custom function's PyObject (THPFunction) actually only holds
            # a weak reference to the C++ node!
5351 class MyFunc(Function): 5352 @staticmethod 5353 def forward(ctx, x): 5354 ctx.save_for_backward(x) 5355 return x 5356 5357 @staticmethod 5358 def backward(ctx, gO): 5359 (x,) = ctx.saved_tensors 5360 return MyFunc2.apply(x) 5361 5362 class MyFunc2(Function): 5363 @staticmethod 5364 def forward(ctx, x): 5365 return x 5366 5367 @staticmethod 5368 def backward(ctx, gO): 5369 return gO + float("NaN") 5370 5371 inp = torch.rand(1, requires_grad=True) 5372 out = MyFunc.apply(inp) 5373 (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True) 5374 5375 with warnings.catch_warnings(record=True) as w: 5376 with self.assertRaisesRegex( 5377 RuntimeError, 5378 "Function 'MyFunc2Backward' returned nan values in its 0th output.", 5379 ): 5380 with detect_anomaly(): 5381 ginp.backward() 5382 5383 class Foo: 5384 pass 5385 5386 my_obj = Foo() 5387 meta_dict = out.grad_fn.metadata 5388 meta_dict[0] = my_obj 5389 ref = weakref.ref(my_obj) 5390 return out, ref 5391 5392 t, ref = get_ref() 5393 self.assertIsNotNone(ref()) 5394 del t 5395 self.assertIsNone(ref()) 5396 5397 def test_anomaly_mode_no_check_nan(self): 5398 class MyFunc(torch.autograd.Function): 5399 @staticmethod 5400 def forward(ctx, inp): 5401 return inp.clone() 5402 5403 @staticmethod 5404 def backward(ctx, gO): 5405 return torch.tensor(float("nan")).expand(10, 10) 5406 5407 def run_fn(a): 5408 out = MyFunc.apply(a) 5409 return out.sum() 5410 5411 with warnings.catch_warnings(record=True) as w: 5412 with torch.autograd.detect_anomaly(check_nan=False): 5413 inp = torch.rand(10, 10, requires_grad=True) 5414 out = run_fn(inp) 5415 out.backward(retain_graph=True) 5416 5417 with torch.autograd.detect_anomaly(check_nan=True): 5418 with self.assertRaisesRegex( 5419 RuntimeError, 5420 "Function 'MyFuncBackward' returned nan values in its 0th output.", 5421 ): 5422 out.backward(retain_graph=True) 5423 5424 out.backward() 5425 5426 def test_no_grad_copy(self): 5427 # create autograd function that saves grad pointer as class static 5428 class MyFunc(Function): 5429 static_grad_ptr = None 5430 5431 @staticmethod 5432 def forward(ctx, inp1, inp2): 5433 return inp1 + inp2 5434 5435 @staticmethod 5436 def backward(ctx, grad): 5437 MyFunc.static_grad_ptr = grad.data_ptr() 5438 return grad, grad 5439 5440 class NonContGradFunc(Function): 5441 @staticmethod 5442 def forward(ctx, inp1): 5443 ctx.size = inp1.size() 5444 return torch.tensor([1.0]) 5445 5446 @staticmethod 5447 def backward(ctx, grad): 5448 return torch.ones(1).expand(ctx.size) 5449 5450 a = torch.randn(5, 6, requires_grad=True) 5451 b = torch.randn(5, 6, requires_grad=True) 5452 # non-contiguous grad should be copied 5453 NonContGradFunc.apply(MyFunc.apply(a, b)).backward() 5454 self.assertFalse(a.grad.data_ptr() == MyFunc.static_grad_ptr) 5455 self.assertFalse(b.grad.data_ptr() == MyFunc.static_grad_ptr) 5456 # test case that should trigger no copy for one of a,b 5457 a.grad = b.grad = None 5458 MyFunc.apply(a, b)[1][0].backward() 5459 p_g = MyFunc.static_grad_ptr 5460 p_a = a.grad.data_ptr() 5461 p_b = b.grad.data_ptr() 5462 # check a,b uses different grad buffer 5463 self.assertFalse(p_a == p_b) 5464 # check one of them is using the computed buffer 5465 self.assertTrue(p_a == p_g or p_b == p_g) 5466 5467 def test_no_grad_copy_sparse(self): 5468 # create autograd function that saves grad pointer as class static 5469 class MyFunc(Function): 5470 static_grad_ptr = None 5471 5472 @staticmethod 5473 def forward(ctx, inp1, inp2): 5474 return inp1 + inp2 5475 5476 @staticmethod 5477 def 
backward(ctx, grad):
                MyFunc.static_grad_ptr = grad._values().data_ptr()
                return grad, grad

        class NonContGradFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                # Create a sparse tensor with non-contiguous indices and values
                # and return it as the grad.
                v = torch.rand(1, 3)
                i = torch.ones(1, 1, dtype=torch.long)
                nv = v.expand(8, 3)
                ni = i.expand(1, 8)
                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
                return ngrad, ngrad

        a = torch.randn(10, 3, requires_grad=True)
        b = torch.randn(10, 3, requires_grad=True)
        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.tensor([0, 4])
        import torch.nn.functional as F

        # test case that should trigger no copy for one of a, b
        emb_matrix = MyFunc.apply(a, b)
        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
        loss.backward(retain_graph=True)
        p_g = MyFunc.static_grad_ptr
        p_a = a.grad._values().data_ptr()
        p_b = b.grad._values().data_ptr()
        # check that a and b use different grad buffers
        self.assertFalse(p_a == p_b)
        # check that one of them is using the computed buffer
        self.assertTrue(p_a == p_g or p_b == p_g)

        # Run backward multiple times to ensure accumulation works.
        for i in range(10):
            loss.backward(retain_graph=True)

        # With non-contiguous indices and values, we should trigger a copy.
        a.grad = b.grad = None
        emb_matrix = NonContGradFunc.apply(a, b)
        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
        loss.backward(retain_graph=True)
        p_g = NonContGradFunc.static_grad_ptr
        p_a = a.grad._values().data_ptr()
        p_b = b.grad._values().data_ptr()
        # check that a and b use different grad buffers
        self.assertFalse(p_a == p_b)
        # Verify we cloned both grads.
        self.assertFalse(p_a == p_g)
        self.assertFalse(p_b == p_g)

        # Run backward multiple times to ensure accumulation works.
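        # (Each extra backward produces another sparse grad that is accumulated
        # into the existing sparse a.grad / b.grad rather than replacing it.)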
5537 for i in range(10): 5538 loss.backward(retain_graph=True) 5539 5540 def test_gradcheck_single_input(self): 5541 def check(fast_mode): 5542 def f(inp): 5543 return inp.mul(5) 5544 5545 gradcheck( 5546 f, 5547 torch.rand(10, dtype=torch.float64, requires_grad=True), 5548 fast_mode=fast_mode, 5549 ) 5550 gradgradcheck( 5551 f, 5552 torch.rand(10, dtype=torch.float64, requires_grad=True), 5553 fast_mode=fast_mode, 5554 ) 5555 5556 check(fast_mode=True) 5557 check(fast_mode=False) 5558 5559 @parametrize( 5560 "layout", 5561 ( 5562 torch.sparse_coo, 5563 torch.sparse_csr, 5564 torch.sparse_csc, 5565 torch.sparse_bsr, 5566 torch.sparse_bsc, 5567 ), 5568 ) 5569 def test_gradcheck_input(self, layout): 5570 if layout in {torch.sparse_bsr, torch.sparse_bsc}: 5571 blocksize = (2, 2) 5572 size = (4, 8) 5573 else: 5574 blocksize = None 5575 size = (2, 2) 5576 5577 def check(fast_mode, masked): 5578 def fn(sparse): 5579 return torch.sum(sparse) 5580 5581 gradcheck( 5582 fn, 5583 torch.rand(size, dtype=torch.double) 5584 .to_sparse(layout=layout, blocksize=blocksize) 5585 .requires_grad_(), 5586 masked=masked, 5587 check_batched_grad=False, 5588 fast_mode=fast_mode, 5589 ) 5590 5591 for fast_mode, masked in product(*[(True, False)] * 2): 5592 check(fast_mode=fast_mode, masked=masked) 5593 5594 def test_gradcheck_nondeterministic(self): 5595 class NonDetFunc(Function): 5596 @staticmethod 5597 def forward(ctx, x, jitter=0.0): 5598 ctx._jitter = jitter 5599 return x 5600 5601 @staticmethod 5602 def backward(ctx, grad_out): 5603 return ( 5604 NonDetFunc.apply(grad_out, ctx._jitter) 5605 * (1 + torch.rand_like(grad_out) * ctx._jitter), 5606 None, 5607 ) 5608 5609 def check(fast_mode): 5610 inp = torch.randn(5, 5, dtype=torch.double, requires_grad=True) 5611 gradcheck( 5612 lambda x: NonDetFunc.apply(x, 0.0), 5613 inp, 5614 check_batched_grad=False, 5615 fast_mode=fast_mode, 5616 ) 5617 with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"): 5618 gradcheck( 5619 lambda x: NonDetFunc.apply(x, 1e-6), 5620 inp, 5621 check_batched_grad=False, 5622 fast_mode=fast_mode, 5623 ) 5624 with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"): 5625 gradgradcheck( 5626 lambda x: NonDetFunc.apply(x, 1e-12), 5627 inp, 5628 check_batched_grad=False, 5629 fast_mode=fast_mode, 5630 ) 5631 gradcheck( 5632 lambda x: NonDetFunc.apply(x, 0.0), 5633 inp, 5634 nondet_tol=1e-5, 5635 check_batched_grad=False, 5636 fast_mode=fast_mode, 5637 ) 5638 gradcheck( 5639 lambda x: NonDetFunc.apply(x, 1e-6), 5640 inp, 5641 nondet_tol=1e-5, 5642 check_batched_grad=False, 5643 fast_mode=fast_mode, 5644 ) 5645 gradgradcheck( 5646 lambda x: NonDetFunc.apply(x, 1e-12), 5647 inp, 5648 nondet_tol=1e-5, 5649 check_batched_grad=False, 5650 fast_mode=fast_mode, 5651 ) 5652 5653 check(fast_mode=True) 5654 check(fast_mode=False) 5655 5656 def test_gradcheck_validates_inputs(self): 5657 def check(fast_mode): 5658 x = torch.rand(10, requires_grad=True).to_sparse() 5659 self.assertTrue( 5660 gradcheck( 5661 lambda x: x.to_dense(), 5662 (x,), 5663 check_batched_grad=False, 5664 atol=1e-1, 5665 fast_mode=fast_mode, 5666 masked=True, 5667 ) 5668 ) 5669 self.assertFalse( 5670 gradcheck( 5671 lambda x: x.to_dense(), 5672 (x,), 5673 masked=False, 5674 check_batched_grad=False, 5675 raise_exception=False, 5676 fast_mode=fast_mode, 5677 ) 5678 ) 5679 self.assertTrue( 5680 gradcheck( 5681 lambda x: x.to_dense(masked_grad=False), 5682 (x,), 5683 masked=False, 5684 atol=1e-1, 5685 check_batched_grad=False, 5686 raise_exception=False, 
5687 fast_mode=fast_mode, 5688 ) 5689 ) 5690 5691 # when none of the inputs require grad (always raises even if raise_exception=False) 5692 x = torch.rand(10, requires_grad=False) 5693 with self.assertRaisesRegex( 5694 ValueError, "at least one input tensor to require gradient" 5695 ): 5696 gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode) 5697 5698 # (warning) when inputs are not double precision 5699 x = torch.ones(1, dtype=torch.float32, requires_grad=True) 5700 with self.assertWarnsRegex( 5701 UserWarning, "Input #0 requires gradient and is not a double precision" 5702 ): 5703 self.assertTrue( 5704 gradcheck(lambda x: x, (x,), atol=1e-1, fast_mode=fast_mode) 5705 ) 5706 5707 # when layout is not mkldnn(aka has strides) and input has a dimension with stride 0. (always raises 5708 # even if raise_exception=False) 5709 x = torch.ones(1, dtype=torch.float64, requires_grad=True) 5710 x = x.expand((2, 2)) 5711 with self.assertRaisesRegex( 5712 RuntimeError, "The 0th input has a dimension with stride 0" 5713 ): 5714 gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode) 5715 5716 check(fast_mode=True) 5717 check(fast_mode=False) 5718 5719 @unittest.skipIf( 5720 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 5721 ) 5722 def test_gradcheck_validates_input_mkldnn(self): 5723 # when mkldnn inputs, forward mode testing is not allowed 5724 # Update tolerances below to make sure the gradient match even in single precision floats 5725 # Use the warning assert to hide the float32 warning 5726 x = torch.ones(1).to_mkldnn().requires_grad_() 5727 with self.assertWarnsRegex( 5728 UserWarning, "Input #0 requires gradient and is not a double precision" 5729 ): 5730 with self.assertRaisesRegex( 5731 ValueError, "MKLDNN inputs are not support for forward AD gradcheck." 5732 ): 5733 gradcheck( 5734 lambda x: x.to_dense(), 5735 (x,), 5736 raise_exception=False, 5737 fast_mode=False, 5738 check_forward_ad=True, 5739 atol=1e-1, 5740 rtol=1e-1, 5741 ) 5742 5743 with self.assertWarnsRegex( 5744 UserWarning, "Input #0 requires gradient and is not a double precision" 5745 ): 5746 with self.assertRaisesRegex( 5747 ValueError, "MKLDNN inputs are not support for forward AD gradcheck." 
5748 ): 5749 gradcheck( 5750 lambda x: x.to_dense(), 5751 (x,), 5752 raise_exception=False, 5753 fast_mode=True, 5754 check_forward_ad=True, 5755 atol=1e-1, 5756 rtol=1e-1, 5757 ) 5758 5759 @unittest.skipIf( 5760 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 5761 ) 5762 def test_gradcheck_test_outputs(self): 5763 def check(fast_mode): 5764 # when sparse outputs (always raise even if raise_exception=False) 5765 x = torch.rand(10, requires_grad=True).to_sparse() 5766 with self.assertRaisesRegex( 5767 ValueError, "Sparse output is not supported at gradcheck yet" 5768 ): 5769 gradcheck( 5770 lambda x: x, 5771 (x,), 5772 masked=True, 5773 check_batched_grad=False, 5774 raise_exception=False, 5775 fast_mode=fast_mode, 5776 ) 5777 5778 # when mkldnn outputs (always raise even if raise_exception=False) 5779 root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True) 5780 with self.assertRaisesRegex( 5781 ValueError, "MKLDNN output is not supported at gradcheck yet" 5782 ): 5783 gradcheck( 5784 lambda x: x.to_mkldnn(), 5785 (root,), 5786 check_batched_grad=False, 5787 raise_exception=False, 5788 fast_mode=fast_mode, 5789 ) 5790 5791 check(fast_mode=True) 5792 check(fast_mode=False) 5793 5794 def test_gradcheck_check_no_differentiable_outputs(self): 5795 def check(fast_mode): 5796 # When none of the outputs are differentiable, but numerical gradient is not zero 5797 x = torch.ones((1,), requires_grad=True) 5798 with self.assertRaisesRegex( 5799 RuntimeError, "Numerical gradient for function expected to be zero" 5800 ): 5801 gradcheck(lambda x: torch.tensor([x]), x) 5802 self.assertFalse( 5803 gradcheck( 5804 lambda x: torch.tensor([x]), 5805 x, 5806 raise_exception=False, 5807 fast_mode=fast_mode, 5808 ) 5809 ) 5810 5811 # succeed when no outputs at all 5812 self.assertTrue(gradcheck(lambda x: (), (x,), fast_mode=fast_mode)) 5813 5814 check(fast_mode=True) 5815 check(fast_mode=False) 5816 5817 def test_gradcheck_check_batched_grad(self): 5818 def check(fast_mode): 5819 x = torch.rand(10, dtype=torch.double, requires_grad=True).to_sparse() 5820 # runtime error while compute batched grad (print big error) 5821 with self.assertRaisesRegex( 5822 RuntimeError, 5823 "gradcheck or gradgradcheck failed while testing batched gradient", 5824 ): 5825 gradcheck( 5826 lambda x: x.to_dense(), 5827 (x,), 5828 masked=True, 5829 check_batched_grad=True, 5830 fast_mode=fast_mode, 5831 ) 5832 self.assertFalse( 5833 gradcheck( 5834 lambda x: x.to_dense(), 5835 (x,), 5836 masked=True, 5837 check_batched_grad=True, 5838 raise_exception=False, 5839 fast_mode=fast_mode, 5840 ) 5841 ) 5842 5843 check(fast_mode=True) 5844 check(fast_mode=False) 5845 5846 def test_gradcheck_backward_mul_by_grad_output(self): 5847 # when grad_input is sparse and has incorrect sparse_dim/dense_dim 5848 def check(fast_mode): 5849 def fn(x): 5850 def hook(grad): 5851 if grad is not None: 5852 return grad.to_dense().to_sparse(1) 5853 return grad 5854 5855 y = x.clone() 5856 y.register_hook(hook) 5857 return y.to_dense() 5858 5859 x = torch.ones((2, 2), dtype=torch.double, requires_grad=True).to_sparse() 5860 with self.assertRaisesRegex( 5861 RuntimeError, "grad is sparse tensor, but has incorrect sparse_dim" 5862 ): 5863 gradcheck( 5864 fn, 5865 (x,), 5866 atol=1e-1, 5867 masked=True, 5868 check_batched_grad=False, 5869 fast_mode=fast_mode, 5870 ) 5871 self.assertFalse( 5872 gradcheck( 5873 fn, 5874 (x,), 5875 atol=1e-1, 5876 masked=True, 5877 check_batched_grad=False, 5878 raise_exception=False, 5879 
fast_mode=fast_mode, 5880 ) 5881 ) 5882 5883 # when backward not multiplied by grad_output (non-sparse case) 5884 def fn2(x): 5885 y = x.clone() 5886 y.register_hook(lambda x: x + 1e-2) 5887 return y 5888 5889 x = torch.ones(1, dtype=torch.double, requires_grad=True) 5890 with self.assertRaisesRegex( 5891 RuntimeError, "backward not multiplied by grad_output" 5892 ): 5893 gradcheck(fn2, (x,), atol=1e-1, fast_mode=fast_mode) 5894 self.assertFalse( 5895 gradcheck( 5896 fn2, (x,), atol=1e-1, raise_exception=False, fast_mode=fast_mode 5897 ) 5898 ) 5899 5900 # when backward not multiplied by grad_output (sparse case) 5901 def fn3(x): 5902 y = x.clone().to_dense() 5903 y.register_hook(lambda x: x + 1e-2) 5904 return y 5905 5906 x = torch.ones(1, dtype=torch.double, requires_grad=True).to_sparse() 5907 with self.assertRaisesRegex( 5908 RuntimeError, "backward not multiplied by grad_output" 5909 ): 5910 gradcheck( 5911 fn3, 5912 (x,), 5913 atol=1e-1, 5914 masked=True, 5915 check_batched_grad=False, 5916 fast_mode=fast_mode, 5917 ) 5918 self.assertFalse( 5919 gradcheck( 5920 fn3, 5921 (x,), 5922 atol=1e-1, 5923 masked=True, 5924 check_batched_grad=False, 5925 raise_exception=False, 5926 fast_mode=fast_mode, 5927 ) 5928 ) 5929 5930 # when layout of grad_input is not the same as input 5931 class Test(Function): 5932 @staticmethod 5933 def forward(ctx, x): 5934 return x 5935 5936 @staticmethod 5937 def backward(ctx, x): 5938 return x.to_sparse() 5939 5940 x = torch.ones(1, dtype=torch.double, requires_grad=True) 5941 with self.assertRaisesRegex(RuntimeError, "grad is incorrect layout"): 5942 gradcheck( 5943 Test.apply, (x,), check_batched_grad=False, fast_mode=fast_mode 5944 ) 5945 self.assertFalse( 5946 gradcheck( 5947 Test.apply, 5948 (x,), 5949 check_batched_grad=False, 5950 raise_exception=False, 5951 fast_mode=fast_mode, 5952 ) 5953 ) 5954 5955 check(fast_mode=True) 5956 check(fast_mode=False) 5957 5958 def test_gradcheck_undefined_grad(self): 5959 def check(fast_mode): 5960 # when encounter runtime error while running backward 5961 def fn(x): 5962 def hook(x): 5963 if x is None: 5964 raise RuntimeError("x is undefined") 5965 5966 y = x.clone() 5967 y.register_hook(hook) 5968 return y 5969 5970 x = torch.ones(1, dtype=torch.double, requires_grad=True) 5971 with self.assertWarnsRegex( 5972 UserWarning, 5973 "Backwards compatibility: New undefined gradient support checking feature", 5974 ): 5975 with self.assertRaisesRegex( 5976 RuntimeError, 5977 "Expected backward function to handle undefined output grads", 5978 ): 5979 gradcheck(fn, (x,), fast_mode=fast_mode) 5980 self.assertFalse( 5981 gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode) 5982 ) 5983 5984 check(fast_mode=True) 5985 check(fast_mode=False) 5986 5987 def test_gradcheck_jacobian_mismatch(self): 5988 def check(fast_mode): 5989 def fn(x): # R -> R, C -> C 5990 y = x.clone() 5991 y.register_hook(lambda x: x + 1e-2) 5992 return y 5993 5994 x = torch.ones(2, 2, requires_grad=True) 5995 with self.assertRaisesRegex( 5996 RuntimeError, "Jacobian mismatch for output 0 with respect to input 0" 5997 ): 5998 gradcheck(fn, (x,), fast_mode=fast_mode) 5999 self.assertFalse( 6000 gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode) 6001 ) 6002 6003 x_c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128) 6004 with self.assertRaisesRegex( 6005 RuntimeError, 6006 "While considering the imaginary part of complex outputs only", 6007 ): 6008 gradcheck(fn, (x_c,), fast_mode=False) 6009 self.assertFalse( 6010 gradcheck(fn, 
(x_c,), raise_exception=False, fast_mode=False) 6011 ) 6012 6013 def fn2(x): # R -> C 6014 y = torch.complex(x, x) 6015 y.register_hook(lambda x: x + 1e-2) 6016 return y 6017 6018 x = torch.ones(2, 2, requires_grad=True) 6019 with self.assertRaisesRegex( 6020 RuntimeError, 6021 "While considering the imaginary part of complex outputs only", 6022 ): 6023 gradcheck(fn2, (x,), fast_mode=False) 6024 self.assertFalse( 6025 gradcheck(fn2, (x,), raise_exception=False, fast_mode=False) 6026 ) 6027 6028 def fn3(x): # C -> R 6029 y = torch.real(x) 6030 y.register_hook(lambda x: x + 1e-2) 6031 return y 6032 6033 with self.assertRaisesRegex( 6034 RuntimeError, "Jacobian mismatch for output 0 with respect to input 0" 6035 ): 6036 gradcheck(fn3, (x_c,), fast_mode=False) 6037 self.assertFalse( 6038 gradcheck(fn3, (x_c,), raise_exception=False, fast_mode=False) 6039 ) 6040 6041 check(fast_mode=True) 6042 check(fast_mode=False) 6043 6044 def test_gradcheck_dense_and_sparse_inputs(self): 6045 def check(fast_mode): 6046 def fn(x, y): 6047 return x * y.coalesce().to_dense() 6048 6049 a = torch.rand(2, 2, dtype=torch.double, requires_grad=True) 6050 b = torch.rand(2, 2, dtype=torch.double).to_sparse().requires_grad_(True) 6051 self.assertTrue( 6052 gradcheck( 6053 fn, 6054 (a, b), 6055 masked=True, 6056 check_batched_grad=False, 6057 fast_mode=fast_mode, 6058 ) 6059 ) 6060 6061 check(fast_mode=True) 6062 check(fast_mode=False) 6063 6064 @unittest.skipIf( 6065 not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 6066 ) 6067 def test_gradcheck_multiple_mkldnn_inputs(self): 6068 def check(fast_mode): 6069 def fn(x, y): 6070 return x + y.to_dense() 6071 6072 a = torch.rand(10, requires_grad=True) 6073 b = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True) 6074 self.assertTrue( 6075 gradcheck( 6076 fn, (a, b), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode 6077 ) 6078 ) 6079 6080 def fn2(x, y): 6081 return x.to_dense() + y.to_dense() 6082 6083 c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True) 6084 self.assertTrue( 6085 gradcheck( 6086 fn, (a, c), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode 6087 ) 6088 ) 6089 6090 check(fast_mode=True) 6091 check(fast_mode=False) 6092 6093 def test_gradcheck_output_shape_or_dtype_depend_on_values(self): 6094 def check(fast_mode): 6095 def fn(x): 6096 if torch.all(x >= 1): 6097 return torch.cat([x, x]) 6098 else: 6099 return x 6100 6101 a = torch.ones(1, dtype=torch.double, requires_grad=True) 6102 with self.assertRaisesRegex( 6103 AssertionError, 6104 "return outputs with the same shape when inputs are perturbed", 6105 ): 6106 self.assertTrue(gradcheck(fn, (a,), fast_mode=fast_mode)) 6107 6108 def fn2(x): 6109 if torch.all(x >= 1): 6110 return x.to(torch.float32) 6111 else: 6112 return x 6113 6114 with self.assertRaisesRegex( 6115 AssertionError, 6116 "return outputs with the same dtype when inputs are perturbed", 6117 ): 6118 self.assertTrue(gradcheck(fn2, (a,), fast_mode=fast_mode)) 6119 6120 check(fast_mode=True) 6121 check(fast_mode=False) 6122 6123 def test_gradcheck_complex_non_complex_outputs(self): 6124 def fn(x, y): 6125 z = torch.complex(x, y) 6126 return z, x + 1 6127 6128 a = torch.ones(2, 2, requires_grad=True, dtype=torch.float64) 6129 b = torch.ones(2, 2, requires_grad=True, dtype=torch.float64) 6130 self.assertTrue(gradcheck(fn, (a, b))) 6131 6132 def fn2(z): 6133 return z, torch.real(z) 6134 6135 c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128) 6136 
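        # gradcheck probes the real and imaginary parts of complex outputs separately
        # (matching the error messages asserted in the complex cases above). A minimal
        # sketch of a pure C -> R check, using a hypothetical helper that is defined
        # here but intentionally never called:
        def _complex_to_real_gradcheck_sketch():
            # Random complex input; t.real.sum() is a real-valued, differentiable output.
            z = torch.randn(2, 2, dtype=torch.complex128, requires_grad=True)
            return gradcheck(lambda t: t.real.sum(), (z,))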
self.assertTrue(gradcheck(fn2, (c))) 6137 6138 def test_gradcheck_get_numerical_jacobian(self): 6139 # get_numerical_jacobian is deprecated and no longer used internally by gradcheck 6140 from torch.autograd.gradcheck import get_numerical_jacobian 6141 6142 def fn(inputs): 6143 # get_numerical_jacobian requires fn to take inputs as a tuple 6144 # and returns the jacobian wrt the first output 6145 x = inputs[0] 6146 y = inputs[1] 6147 return 2 * x + y, x + 2 * y 6148 6149 a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64) 6150 b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64) 6151 6152 with self.assertWarnsRegex( 6153 FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API" 6154 ): 6155 jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6) 6156 self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double)) 6157 6158 with self.assertWarnsRegex( 6159 FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API" 6160 ): 6161 jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6) 6162 self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double)) 6163 self.assertEqual(jacobian[1], 1 * torch.eye(4, dtype=torch.double)) 6164 6165 with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"): 6166 jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=2.0) 6167 6168 def test_gradcheck_get_analytical_jacobian(self): 6169 from torch.autograd.gradcheck import get_analytical_jacobian 6170 6171 def fn(x, y): 6172 return 2 * x + y, x + 2 * y 6173 6174 a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64) 6175 b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64) 6176 6177 outputs = fn(a, b) 6178 with self.assertWarnsRegex( 6179 FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API" 6180 ): 6181 ( 6182 jacobians, 6183 reentrant, 6184 correct_grad_sizes, 6185 correct_grad_types, 6186 ) = get_analytical_jacobian((a, b), outputs[0]) 6187 self.assertEqual(jacobians[0], 2 * torch.eye(4, dtype=torch.double)) 6188 self.assertEqual(jacobians[1], 1 * torch.eye(4, dtype=torch.double)) 6189 self.assertTrue(reentrant) 6190 6191 class NonDetFunc(Function): 6192 @staticmethod 6193 def forward(ctx, x, jitter=0.0): 6194 ctx._jitter = jitter 6195 return x 6196 6197 @staticmethod 6198 def backward(ctx, grad_out): 6199 return ( 6200 NonDetFunc.apply(grad_out, ctx._jitter) 6201 * (1 + torch.rand_like(grad_out) * ctx._jitter), 6202 None, 6203 ) 6204 6205 outputs = NonDetFunc.apply(a, 1e-6) 6206 with self.assertWarnsRegex( 6207 FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API" 6208 ): 6209 ( 6210 jacobians, 6211 reentrant, 6212 correct_grad_sizes, 6213 correct_grad_types, 6214 ) = get_analytical_jacobian((a,), outputs) 6215 self.assertFalse(reentrant) 6216 6217 with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"): 6218 jacobians, _, _, _ = get_analytical_jacobian((a,), outputs, grad_out=2.0) 6219 6220 def test_gradcheck_custom_error(self): 6221 from torch.autograd.gradcheck import GradcheckError 6222 6223 def check(fast_mode): 6224 def fn(x): 6225 y = x.clone() 6226 y.register_hook(lambda x: x + 1e-2) 6227 return y 6228 6229 x = torch.ones(2, 2, requires_grad=True) 6230 with self.assertRaisesRegex( 6231 GradcheckError, "Jacobian mismatch for output 0 with respect to input 0" 6232 ): 6233 gradcheck(fn, (x,), fast_mode=fast_mode) 6234 with self.assertRaisesRegex( 6235 RuntimeError, "Jacobian mismatch for output 0 with respect to input 0" 6236 ): 6237 
gradcheck(fn, (x,), fast_mode=fast_mode) 6238 self.assertFalse( 6239 gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode) 6240 ) 6241 6242 def fn2(x): 6243 raise RuntimeError("Not a GradcheckError!") 6244 6245 # Checks that when raise_exception=False, non-GradcheckErrors are not caught by gradcheck 6246 with self.assertRaisesRegex(RuntimeError, "Not a GradcheckError!"): 6247 gradcheck(fn2, (x,), fast_mode=fast_mode, raise_exception=False) 6248 6249 check(fast_mode=True) 6250 check(fast_mode=False) 6251 6252 def test_gradcheck_forward_ad(self): 6253 def fn(x, y): 6254 return x + y, y 6255 6256 def bad_fn(x, y): 6257 # Hacky way to check if we're currently inside a forward ad level 6258 is_running_forward_ad = fwAD._current_level >= 0 6259 6260 if is_running_forward_ad: 6261 y_p, y_d = fwAD.unpack_dual(y) 6262 y = fwAD.make_dual(y_p, y_d * 1.1) 6263 6264 return x + y, y 6265 6266 err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1" 6267 6268 for fast_mode in [True, False]: 6269 # Test for all inputs and outputs being real 6270 x = torch.rand(2, dtype=torch.double, requires_grad=True) 6271 y = torch.rand(2, dtype=torch.double, requires_grad=True) 6272 6273 gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) 6274 with self.assertRaisesRegex(RuntimeError, err_msg): 6275 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) 6276 6277 def basic_mul(x): 6278 return torch.view_as_real(torch.resolve_conj(x * 1j)) 6279 6280 gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode) 6281 6282 # Test for one input and one output being complex 6283 x = torch.rand(2, dtype=torch.cdouble, requires_grad=True) 6284 6285 gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) 6286 with self.assertRaisesRegex(RuntimeError, err_msg): 6287 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) 6288 6289 # Test for all inputs and outputs being complex 6290 y = torch.rand(2, dtype=torch.cdouble, requires_grad=True) 6291 6292 gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) 6293 with self.assertRaisesRegex(RuntimeError, err_msg): 6294 gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode) 6295 6296 def test_gradcheck_forward_ad_runs_with_no_requires_grad(self): 6297 # Currently requires_grad is used as a easy way for gradcheck to know 6298 # which inputs of the function are meant to be differentiable 6299 # This test checks that when the inputs are passed to the function they should not have 6300 # requires_grad=True even though they may have requires_grad=True when passed 6301 # to gradcheck 6302 class UserFn(Function): 6303 @staticmethod 6304 def forward(ctx, x, y): 6305 if fwAD._current_level >= 0: 6306 self.assertFalse(x.requires_grad) 6307 self.assertFalse(y.requires_grad) 6308 return x.clone(), y.clone() 6309 6310 @staticmethod 6311 def jvp(ctx, x_t, y_t): 6312 return x_t, y_t 6313 6314 x = torch.rand(2, dtype=torch.double, requires_grad=True) 6315 y = torch.rand(2, dtype=torch.double, requires_grad=True) 6316 6317 gradcheck( 6318 UserFn.apply, 6319 (x, y), 6320 check_forward_ad=True, 6321 check_undefined_grad=False, 6322 check_backward_ad=False, 6323 check_batched_grad=False, 6324 check_batched_forward_grad=False, 6325 ) 6326 6327 gradcheck( 6328 UserFn.apply, 6329 (x, y), 6330 check_forward_ad=True, 6331 check_undefined_grad=True, 6332 check_backward_ad=False, 6333 check_batched_grad=False, 6334 check_batched_forward_grad=False, 6335 ) 6336 6337 gradcheck( 6338 
UserFn.apply, 6339 (x, y), 6340 check_forward_ad=True, 6341 check_undefined_grad=True, 6342 check_backward_ad=False, 6343 check_batched_grad=False, 6344 check_batched_forward_grad=True, 6345 ) 6346 6347 x = torch.rand(2, dtype=torch.double, requires_grad=True) 6348 y = torch.rand(2, dtype=torch.double, requires_grad=False) 6349 gradcheck( 6350 UserFn.apply, 6351 (x, y), 6352 check_forward_ad=True, 6353 check_undefined_grad=True, 6354 check_backward_ad=False, 6355 check_batched_grad=False, 6356 check_batched_forward_grad=True, 6357 ) 6358 6359 def test_gradcheck_forward_ad_respects_requires_grad(self): 6360 # Currently requires_grad is used as a easy way for gradcheck to know 6361 # which inputs of the function are meant to be differentiable 6362 jvp_count = [0] 6363 6364 class UserFn(Function): 6365 @staticmethod 6366 def forward(ctx, x, y): 6367 return x.clone(), y.clone() 6368 6369 @staticmethod 6370 def jvp(ctx, x_t, y_t): 6371 jvp_count[0] += 1 6372 return x_t, y_t 6373 6374 # NB: In slow gradcheck we need to loop through numel times so use numel = 1 to ensure 6375 # that fast and slow have the same counts 6376 x = torch.rand(1, dtype=torch.double, requires_grad=True) 6377 y = torch.rand(1, dtype=torch.double, requires_grad=True) 6378 gradcheck( 6379 UserFn.apply, 6380 (x, y), 6381 check_forward_ad=True, 6382 check_undefined_grad=False, 6383 check_backward_ad=False, 6384 check_batched_grad=False, 6385 check_batched_forward_grad=False, 6386 ) 6387 self.assertEqual(jvp_count[0], 2) # (2) once per input 6388 jvp_count = [0] 6389 6390 gradcheck( 6391 UserFn.apply, 6392 (x, y), 6393 check_forward_ad=True, 6394 check_undefined_grad=True, 6395 check_backward_ad=False, 6396 check_batched_grad=False, 6397 check_batched_forward_grad=False, 6398 ) 6399 self.assertEqual( 6400 jvp_count[0], 6 6401 ) # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2) 6402 jvp_count = [0] 6403 6404 gradcheck( 6405 UserFn.apply, 6406 (x, y), 6407 check_forward_ad=True, 6408 check_undefined_grad=True, 6409 check_backward_ad=False, 6410 check_batched_grad=False, 6411 check_batched_forward_grad=True, 6412 ) 6413 self.assertEqual( 6414 jvp_count[0], 12 6415 ) # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2) 6416 jvp_count = [0] 6417 6418 # Repeat the previous test except we mark one input with requires_grad=False 6419 # NB: _test_undefined_forward_mode is only (+1), when function has single differentiable input, not (+2)! 6420 # Otherwise, other counts are halved. 
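        # With only x differentiable, the expected count of 5 below breaks down as:
        # +1 for the basic forward-AD check, +1 for the undefined-tangent check (it is
        # not doubled when there is a single differentiable input), and +3 for the
        # batched-forward check (one vmapped evaluation plus a loop over the batch of 2).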
6421 x = torch.rand(1, dtype=torch.double, requires_grad=True) 6422 y = torch.rand(1, dtype=torch.double, requires_grad=False) 6423 gradcheck( 6424 UserFn.apply, 6425 (x, y), 6426 check_forward_ad=True, 6427 check_undefined_grad=True, 6428 check_backward_ad=False, 6429 check_batched_grad=False, 6430 check_batched_forward_grad=True, 6431 ) 6432 self.assertEqual(jvp_count[0], 5) # 1 + 1 + 3 6433 6434 def test_gradcheck_check_forward_or_backward_only(self): 6435 """Depending on settings for check_forward_ad and check_backward_ad, the 6436 correct codepaths should be reached (or not reached) 6437 """ 6438 fwd_fail_err_msg = "FAIL FWD" 6439 bwd_fail_err_msg = "FAIL BWD" 6440 6441 class UserFn(Function): 6442 @staticmethod 6443 def forward(ctx, foo, fwd_bad, bwd_bad): 6444 ctx.fwd_bad = fwd_bad 6445 ctx.bwd_bad = bwd_bad 6446 return foo * 2 6447 6448 @staticmethod 6449 def vjp(ctx, gO): 6450 if ctx.bwd_bad: 6451 raise RuntimeError(bwd_fail_err_msg) 6452 else: 6453 return 2 * gO, None, None 6454 6455 @staticmethod 6456 def jvp(ctx, gI, _1, _2): 6457 if ctx.fwd_bad: 6458 raise RuntimeError(fwd_fail_err_msg) 6459 else: 6460 return 2 * gI 6461 6462 for fast_mode in (True, False): 6463 for check_forward_ad in (True, False): 6464 for check_backward_ad in (True, False): 6465 for fwd_bad in (True, False): 6466 for bwd_bad in (True, False): 6467 fwd_should_fail = fwd_bad and check_forward_ad 6468 bwd_should_fail = bwd_bad and check_backward_ad 6469 6470 def run(): 6471 gradcheck( 6472 UserFn.apply, 6473 (x, fwd_bad, bwd_bad), 6474 check_forward_ad=check_forward_ad, 6475 check_backward_ad=check_backward_ad, 6476 check_undefined_grad=check_backward_ad, 6477 check_batched_grad=check_backward_ad, 6478 fast_mode=fast_mode, 6479 ) 6480 6481 x = torch.rand(2, dtype=torch.double, requires_grad=True) 6482 6483 if not check_forward_ad and not check_backward_ad: 6484 with self.assertRaisesRegex( 6485 AssertionError, "Expected at least one of" 6486 ): 6487 run() 6488 continue 6489 6490 if not fwd_should_fail and not bwd_should_fail: 6491 run() 6492 else: 6493 # If both fail, backward AD failure "hides" forward AD failure 6494 if fwd_should_fail: 6495 fail_msg = fwd_fail_err_msg 6496 if bwd_should_fail: 6497 fail_msg = bwd_fail_err_msg 6498 with self.assertRaisesRegex(RuntimeError, fail_msg): 6499 run() 6500 6501 def test_gradcheck_forward_ad_batched_grad(self): 6502 x = torch.rand(2, dtype=torch.double, requires_grad=True) 6503 6504 # multiple inputs and outputs with non-tensors inputs 6505 def fn1(a: torch.Tensor, b: int): 6506 return a.clone(), a + 1 6507 6508 gradcheck( 6509 fn1, 6510 (x, 1), 6511 check_forward_ad=True, 6512 check_backward_ad=False, 6513 check_batched_grad=False, 6514 check_undefined_grad=False, 6515 check_batched_forward_grad=True, 6516 ) 6517 6518 # unrelated inputs: tangent for c is None 6519 def fn2(a: torch.Tensor, c: torch.Tensor): 6520 return a.clone() 6521 6522 gradcheck( 6523 fn2, 6524 (x, x.clone()), 6525 check_forward_ad=True, 6526 check_backward_ad=False, 6527 check_batched_grad=False, 6528 check_undefined_grad=False, 6529 check_batched_forward_grad=True, 6530 ) 6531 6532 class Fn(Function): 6533 @staticmethod 6534 def forward(ctx, foo): 6535 return foo * 2 6536 6537 @staticmethod 6538 def vjp(ctx, gO): 6539 return gO * 2 6540 6541 @staticmethod 6542 def jvp(ctx, gI): 6543 torch.randn_like(gI) 6544 return gI * 2 6545 6546 msg = "vmap: We do not yet support calling random operations inside of vmap" 6547 with self.assertRaisesRegex(RuntimeError, msg): 6548 gradcheck( 6549 Fn.apply, 
(x,), check_forward_ad=True, check_batched_forward_grad=True 6550 ) 6551 6552 def test_version_counter(self): 6553 x = torch.randn(1, 2) 6554 6555 # In-place op bumps version 6556 x_saved_version = x._version 6557 x.add_(1).add_(1) 6558 self.assertTrue(x._version > x_saved_version) 6559 6560 # Differentiable view shares version counter 6561 xz = x[:] 6562 self.assertTrue(x._version == xz._version) 6563 xz.add_(1) 6564 self.assertTrue(x._version == xz._version) 6565 6566 # `x.data = y` preserves version counter of `x` 6567 x_saved_version = x._version 6568 x.data = torch.randn(2, 3) 6569 self.assertTrue(x._version == x_saved_version) 6570 x.add_(1) 6571 self.assertTrue(x._version > x_saved_version) 6572 # Make sure `x` is still using the same version counter it shares with `xz` 6573 self.assertTrue(x._version == xz._version) 6574 6575 # In-place op on `xz` also updates version of `x`, 6576 # because they share the version counter 6577 xz.add_(1) 6578 self.assertTrue(x._version == xz._version) 6579 6580 def test_set_data_tensorimpl_type(self): 6581 # Dense tensor has impl of type `TensorImpl`, while sparse tensor has impl 6582 # of type `SparseTensorImpl`. 6583 x = torch.randn(1, 2) 6584 x_s = torch.sparse_coo_tensor(torch.zeros([1, 1]), torch.ones([1])) 6585 with self.assertRaisesRegex(RuntimeError, "incompatible tensor type"): 6586 x.data = x_s 6587 6588 def test_set_data_preserve_pyobj(self): 6589 a = torch.randn(1, 2) 6590 b = torch.randn(1, 2) 6591 b_id_saved = id(b) 6592 b.data = a 6593 self.assertTrue(b_id_saved == id(b)) 6594 6595 def test_set_data_self_requires_grad(self): 6596 a = torch.tensor(1.0, requires_grad=True) 6597 b = torch.tensor(2.0) 6598 c = torch.tensor(3, dtype=torch.int64) 6599 a.data = b 6600 with self.assertRaisesRegex( 6601 RuntimeError, "must be floating point or complex dtype" 6602 ): 6603 a.data = c 6604 6605 @unittest.skipIf(IS_WINDOWS, "Skipping because doesn't work for windows") 6606 def test_thread_shutdown(self): 6607 code = """import torch 6608from torch.autograd import Function 6609class MyFunction(Function): 6610 @staticmethod 6611 def forward(ctx, x): 6612 return x 6613 6614 @staticmethod 6615 def backward(ctx, grad): 6616 return grad 6617 6618# Run on cuda if it is available to ensure that the worker thread 6619# is properly initialized by the time we exit. 6620device = "cuda" if torch.cuda.is_available() else "cpu" 6621 6622for shape in [(1,), ()]: 6623 v = torch.ones(shape, requires_grad=True, device=device) 6624 MyFunction.apply(v).backward() 6625""" 6626 s = TestCase.runWithPytorchAPIUsageStderr(code) 6627 # The autograd engine creates worker threads only when GPU devices are present. 6628 # So make sure that we do shutdown threads when we're testing cuda and make sure 6629 # that there is no thread to shutdown when we're not using cuda. 
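        # In other words: on accelerator builds (CUDA, MPS, or XPU) the child process is
        # expected to log the thread_shutdown API-usage event, while a CPU-only run must
        # not, since no device worker threads were ever started.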
6630 if TEST_CUDA or torch.backends.mps.is_available() or torch.xpu.is_available(): 6631 self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown") 6632 else: 6633 self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown") 6634 6635 @unittest.skipIf( 6636 IS_MACOS, 6637 "Fails with SIGBUS on macOS; https://github.com/pytorch/pytorch/issues/25941", 6638 ) 6639 def test_deep_reentrant(self): 6640 class DeepReentrant(Function): 6641 @staticmethod 6642 def forward(ctx, x): 6643 with torch.enable_grad(): 6644 ctx.x = Variable(x.detach(), requires_grad=True) 6645 ctx.x = ctx.x - 1 6646 return ctx.x.detach() 6647 6648 @staticmethod 6649 def backward(ctx, x): 6650 if ctx.x < 0: 6651 return x 6652 with torch.enable_grad(): 6653 DeepReentrant.apply(ctx.x).sum().backward() 6654 return x 6655 6656 # Test stack overflow escape mechanism 6657 v = torch.tensor(2000.0, requires_grad=True) 6658 # This will cause stack overflow if reentrant calls are handled 6659 # in the same thread recursively 6660 DeepReentrant.apply(v).sum().backward() 6661 6662 # Test stack overflow escape mechanism multiple times 6663 # to ensure reusing workers in the pool works fine 6664 v2 = torch.tensor(200.0, requires_grad=True) 6665 DeepReentrant.apply(v2).sum().backward() 6666 6667 def test_reentrant_priority(self): 6668 order = [] 6669 6670 class MyFunction(Function): 6671 @staticmethod 6672 def forward(ctx, x): 6673 return x 6674 6675 @staticmethod 6676 def backward(ctx, x): 6677 order.append("MyFunction") 6678 return x 6679 6680 class Reentrant(Function): 6681 @staticmethod 6682 def forward(ctx, x): 6683 with torch.enable_grad(): 6684 ctx.x = Variable(x.detach(), requires_grad=True) 6685 ctx.x = ctx.x - 1 6686 return ctx.x.detach() 6687 6688 @staticmethod 6689 def backward(ctx, x): 6690 order.append("Reentrant") 6691 if ctx.x < 0: 6692 return x 6693 with torch.enable_grad(): 6694 Reentrant.apply(ctx.x).backward() 6695 return x 6696 6697 a = MyFunction.apply(torch.tensor(6.0, requires_grad=True)) 6698 b = Reentrant.apply(torch.tensor(9.0, requires_grad=True)) 6699 v = a * b 6700 v.backward() 6701 # The tasks for the Reentrant and MyFunction backward() will be added 6702 # to the queue in the autograd engine at the same time. The backward 6703 # for Reentrant will be executed first, which will then add other 6704 # backward tasks to the queue. 
We want to ensure all the reentrant tasks 6705 # are prioritized over the MyFunction backward task regardless of their 6706 # sequence numbers 6707 self.assertEqual(len(order), 11) 6708 self.assertEqual(order.count("Reentrant"), 10) 6709 self.assertEqual(order[-1], "MyFunction") 6710 6711 @slowTest 6712 def test_checkpointing(self): 6713 num_inp = 2000 6714 nz_inp = 10 6715 nz_out = 10 6716 nz_bottleneck = 1000 6717 6718 # small proxy network for some complex reasoning we want to do per input 6719 module = nn.Sequential( 6720 nn.Linear(nz_inp, nz_bottleneck), 6721 nn.ReLU(), 6722 nn.Linear(nz_bottleneck, nz_inp), 6723 ) 6724 6725 feat_combined = [] 6726 for r in range(num_inp): 6727 data_r = torch.empty(1, nz_inp) 6728 data_r.uniform_() 6729 data_r.requires_grad = True 6730 feat_r = checkpoint(module, data_r, use_reentrant=True) 6731 feat_combined.append(feat_r) 6732 6733 # compute mean as a proxy for some joint reasoning 6734 mean_combined = torch.stack(feat_combined).mean() 6735 mean_combined.backward() 6736 6737 def _test_checkpointing_non_reentrant_autocast(self, device_type): 6738 for enabled in [True, False]: 6739 6740 def foo(x, y, z): 6741 # torch.mm is on autocast's list of ops that should run in 6742 # the autocast precision 6743 x = torch.mm(x, y) 6744 y = torch.mm(x, z) 6745 z = torch.mm(z, z) 6746 expected_dtype = torch.float32 if not enabled else torch.bfloat16 6747 self.assertEqual(expected_dtype, z.dtype) 6748 return z 6749 6750 x = torch.randn(3, 3, requires_grad=True) 6751 y = torch.randn(3, 3, requires_grad=True) 6752 z = torch.randn(3, 3, requires_grad=True) 6753 if device_type == "cuda": 6754 x = x.cuda() 6755 y = y.cuda() 6756 z = z.cuda() 6757 6758 with torch.autocast( 6759 enabled=enabled, device_type=device_type, dtype=torch.bfloat16 6760 ): 6761 loss = checkpoint(foo, x, y, z, use_reentrant=False) 6762 loss = loss.sum() 6763 6764 # Without saving + recasting the autocast type, would raise error in autograd 6765 # about mismatched dtypes. 6766 loss.backward() # triggers recomputation to check it runs in bfloat 6767 6768 def test_checkpointing_non_reentrant_autocast_cpu(self): 6769 """ 6770 Test that autocast args such as the dtype are preserved during non-reentrant 6771 checkpoint recomputation on CPU. 6772 """ 6773 self._test_checkpointing_non_reentrant_autocast(device_type="cpu") 6774 6775 @unittest.skipIf( 6776 not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(), 6777 "Test requires CUDA bf16 support", 6778 ) 6779 def test_checkpointing_non_reentrant_autocast_gpu(self): 6780 """ 6781 Test that autocast args/kwargs such as the dtype are preserved during 6782 non-reentrant checkpoint recomputation on GPU. 6783 """ 6784 self._test_checkpointing_non_reentrant_autocast(device_type="cuda") 6785 6786 @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA") 6787 @slowTest 6788 def test_checkpointing_without_reentrant_memory_savings(self): 6789 class MyModel(nn.Module): 6790 def __init__(self, n, use_checkpoint, use_reentrant): 6791 super().__init__() 6792 self.n = n 6793 self.use_checkpoint = use_checkpoint 6794 self.use_reentrant = use_reentrant 6795 self.layers = nn.ModuleList() 6796 for i in range(self.n): 6797 layer = nn.Sequential( 6798 nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256) 6799 ) 6800 self.layers.append(layer) 6801 # pre-allocate the grad so that increased memory usage is mainly 6802 # due to activations. 
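                # With the .grad buffers created up front, the peak-memory numbers
                # compared after the three runs differ mostly in how many activations
                # are kept alive, which is the quantity checkpointing is meant to reduce.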
6803 for layer in self.layers: 6804 for lin in layer: 6805 lin.weight.grad = torch.ones_like(lin.weight) 6806 lin.bias.grad = torch.ones_like(lin.bias) 6807 6808 def forward(self, x): 6809 for i in range(self.n): 6810 if not self.use_checkpoint: 6811 x = self.layers[i](x) 6812 else: 6813 x = checkpoint( 6814 self.layers[i], x, use_reentrant=self.use_reentrant 6815 ) 6816 6817 return x 6818 6819 model_no_checkpoint = MyModel( 6820 8, use_checkpoint=False, use_reentrant=False 6821 ).cuda() 6822 model_reentrant_checkpoint = MyModel( 6823 8, use_checkpoint=True, use_reentrant=True 6824 ).cuda() 6825 model_no_reentrant_checkpoint = MyModel( 6826 8, use_checkpoint=True, use_reentrant=False 6827 ).cuda() 6828 6829 x = torch.randn(100, 256, requires_grad=True, device="cuda") 6830 6831 torch.cuda.reset_peak_memory_stats() 6832 loss = model_no_checkpoint(x.clone()).sum() 6833 loss.backward() 6834 mem_no_checkpoint = torch.cuda.max_memory_allocated() 6835 6836 torch.cuda.reset_peak_memory_stats() 6837 loss = model_reentrant_checkpoint(x.clone()).sum() 6838 loss.backward() 6839 mem_reentrant_checkpoint = torch.cuda.max_memory_allocated() 6840 6841 torch.cuda.reset_peak_memory_stats() 6842 loss = model_no_reentrant_checkpoint(x.clone()).sum() 6843 loss.backward() 6844 mem_no_reentrant_checkpoint = torch.cuda.max_memory_allocated() 6845 6846 self.assertTrue(mem_reentrant_checkpoint < mem_no_checkpoint) 6847 self.assertTrue(mem_no_reentrant_checkpoint < mem_no_checkpoint) 6848 6849 def test_checkpointing_without_reentrant_custom_function_works(self): 6850 msg = "Unpack is being triggered for a tensor that was already unpacked once" 6851 6852 class MyFunc(torch.autograd.Function): 6853 @staticmethod 6854 def forward(ctx, x, y, z): 6855 w = x * y * z 6856 out = w + w 6857 ctx.save_for_backward(x, y, z, w, out) 6858 return out 6859 6860 @staticmethod 6861 def backward(ctx, grad_out): 6862 x, y, z, w, out = ctx.saved_tensors 6863 # Accessing the saved Tensors a second time will raise because 6864 # recomputed tensors get cleared as soon as they are unpacked. 6865 # A recomputation is only triggered if your backward has a new 6866 # graph-task id. 
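                # Within this one backward call, the second ctx.saved_tensors access hits
                # recomputed values that were already unpacked (and therefore cleared),
                # which is exactly the error message asserted below.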
6867 with self.assertRaisesRegex(RuntimeError, msg): 6868 x_2, y_2, z_2, w_2, out_2 = ctx.saved_tensors 6869 return x, y, z 6870 6871 x = torch.tensor(1.0, requires_grad=True) 6872 y = torch.tensor(2.0, requires_grad=True) 6873 z = torch.tensor(3.0, requires_grad=True) 6874 6875 def foo(x, y, z): 6876 x = x * y * z 6877 y = y * y * z 6878 z = z * z 6879 out = MyFunc.apply(x, y, z) 6880 return out 6881 6882 out = checkpoint(foo, x, y, z, use_reentrant=False) 6883 out.sum().backward() 6884 6885 def test_checkpointing_without_reentrant_with_context_fn(self): 6886 class VerboseTorchDispatchMode(TorchDispatchMode): 6887 def __init__(self) -> None: 6888 self.operators = [] 6889 6890 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 6891 if kwargs is None: 6892 kwargs = {} 6893 self.operators.append(func.__name__) 6894 return func(*args, **kwargs) 6895 6896 x = torch.tensor(1.0, requires_grad=True) 6897 verbose_mode = VerboseTorchDispatchMode() 6898 6899 def context_fn(): 6900 return verbose_mode, contextlib.nullcontext() 6901 6902 out = checkpoint( 6903 lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn 6904 ) 6905 self.assertEqual(verbose_mode.operators, ["exp.default"]) 6906 6907 verbose_mode.operators = [] 6908 6909 def context_fn(): 6910 return contextlib.nullcontext(), verbose_mode 6911 6912 out = checkpoint( 6913 lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn 6914 ) 6915 out.backward() 6916 self.assertEqual( 6917 verbose_mode.operators, ["exp.default", "detach.default", "detach.default"] 6918 ) 6919 6920 with self.assertRaisesRegex( 6921 Exception, "only supported when use_reentrant=False" 6922 ): 6923 out = checkpoint( 6924 lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn 6925 ) 6926 6927 def test_checkpoint_warns_if_use_reentrant_not_passed_explcitly(self): 6928 a = torch.randn(1, requires_grad=True) 6929 6930 # Passing explicitly should not warn 6931 self.assertNotWarn(lambda: checkpoint(lambda x: x, a, use_reentrant=False)) 6932 6933 # Not passing explicitly warns 6934 with self.assertWarnsOnceRegex( 6935 UserWarning, ".*the use_reentrant parameter should be passed explicitly.*" 6936 ): 6937 checkpoint(lambda x: x, a) 6938 6939 def test_checkpoint_sequential_warns_if_use_reentrant_not_passed_explcitly(self): 6940 a = torch.randn(3, requires_grad=True) 6941 modules_list = [ 6942 torch.nn.Linear(3, 3), 6943 torch.nn.Linear(3, 3), 6944 torch.nn.Linear(3, 3), 6945 ] 6946 6947 # Passing explicitly should not warn 6948 self.assertNotWarn( 6949 lambda: checkpoint_sequential(modules_list, 3, a, use_reentrant=False) 6950 ) 6951 6952 # Not passing explicitly warns 6953 with self.assertWarnsOnceRegex( 6954 UserWarning, ".*the use_reentrant parameter should be passed explicitly.*" 6955 ): 6956 checkpoint_sequential(modules_list, 3, a) 6957 6958 def test_checkpoint_detects_non_determinism(self): 6959 def save_3_tensors(x): 6960 out = x.sin().exp() 6961 out = out.sin() 6962 return out 6963 6964 def save_2_tensors(x): 6965 return x.sin().exp() 6966 6967 def save_2_tensors_alt(x): 6968 return x.sin() * torch.tensor([1.0, 2.0]) 6969 6970 def get_non_det_fn(orig_fn, recompute_fn): 6971 counter = [0] 6972 6973 def fn(x): 6974 if counter[0] == 0: 6975 counter[0] += 1 6976 return orig_fn(x) 6977 else: 6978 return recompute_fn(x) 6979 6980 return fn 6981 6982 a = torch.randn(1, requires_grad=True) 6983 6984 # Save fewer tensors during recompute 6985 fn = get_non_det_fn(orig_fn=save_3_tensors, recompute_fn=save_2_tensors) 6986 with 
self.assertRaisesRegex( 6987 RuntimeError, "A different number of tensors was saved" 6988 ): 6989 out = checkpoint(fn, a, use_reentrant=False) 6990 out.backward() 6991 6992 # Save more tensors during recompute 6993 fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors) 6994 with torch.utils.checkpoint.set_checkpoint_early_stop(False): 6995 with self.assertRaisesRegex( 6996 RuntimeError, "trying to save more tensors during recomputation" 6997 ): 6998 out = checkpoint(fn, a, use_reentrant=False) 6999 out.backward() 7000 7001 fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors) 7002 # If early stopping is enabled, we would not raise (the results would be correct anyway) 7003 out = checkpoint(fn, a, use_reentrant=False) 7004 out.backward() 7005 7006 # Save the same number of tensors but the shape is different 7007 fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt) 7008 with self.assertRaisesRegex(RuntimeError, "tensors have different metadata"): 7009 out = checkpoint(fn, a, use_reentrant=False) 7010 out.backward() 7011 7012 # Get the debug message if debug=True 7013 fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt) 7014 7015 with self.assertRaisesRegex( 7016 RuntimeError, 7017 "You are seeing this error because you passed `debug=True` to checkpoint", 7018 ): 7019 out = checkpoint(fn, a, use_reentrant=False, debug=True) 7020 out.backward() 7021 7022 fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt) 7023 7024 with self.assertRaisesRegex( 7025 RuntimeError, 7026 "You are seeing this error because you passed `debug=True` to checkpoint", 7027 ): 7028 with torch.utils.checkpoint.set_checkpoint_debug_enabled(True): 7029 out = checkpoint(fn, a, use_reentrant=False, debug=False) 7030 out.backward() 7031 7032 fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt) 7033 7034 with self.assertRaisesRegex( 7035 RuntimeError, "Recomputed values for the following tensors have different" 7036 ): 7037 with torch.utils.checkpoint.set_checkpoint_debug_enabled(False): 7038 out = checkpoint(fn, a, use_reentrant=False, debug=True) 7039 out.backward() 7040 7041 def test_access_saved_tensor_twice_without_recomputation_works(self): 7042 count = [0] 7043 7044 def foo(a): 7045 count[0] += 1 7046 b = a * a 7047 c = a * b 7048 d = torch.exp(a) 7049 return d 7050 7051 a = torch.randn(5, requires_grad=True) 7052 d = checkpoint(foo, a, use_reentrant=False) 7053 self.assertEqual(count[0], 1) 7054 # Recomputed variables only persist within a particular backward call. 7055 # If _saved_result is accessed outside of a backward, it will trigger 7056 # a recompute. And afterwards, those recomputed results are immediately 7057 # cleared. 7058 d.grad_fn._saved_result 7059 self.assertEqual(count[0], 2) 7060 # Second access will trigger another recompute 7061 d.grad_fn._saved_result 7062 self.assertEqual(count[0], 3) 7063 # Backward clears the saved variable 7064 d.sum().backward() 7065 self.assertEqual(count[0], 4) 7066 # Now it raises an error 7067 with self.assertRaisesRegex( 7068 RuntimeError, 7069 "or directly access saved tensors after they have already been freed", 7070 ): 7071 d.grad_fn._saved_result 7072 7073 @slowTest 7074 @parametrize("input_requires_grad", [True, False]) 7075 def test_checkpointing_without_reentrant(self, input_requires_grad): 7076 """ 7077 Basic test for checkpoint without reentrant autograd. 
7078 """ 7079 num_inp = 2000 7080 nz_inp = 10 7081 nz_out = 10 7082 nz_bottleneck = 1000 7083 7084 # small proxy network for some complex reasoning we want to do per input 7085 module = nn.Sequential( 7086 nn.Linear(nz_inp, nz_bottleneck), 7087 nn.ReLU(), 7088 nn.Linear(nz_bottleneck, nz_inp), 7089 ) 7090 7091 # Module holder for testing activation checkpointing with no_reentrant 7092 # supports kwargs. 7093 class MyModule(nn.Module): 7094 def __init__(self, mod): 7095 super().__init__() 7096 self.module = mod 7097 7098 def forward(self, data): 7099 return self.module(data) 7100 7101 module = MyModule(mod=module) 7102 7103 # Run model with and without checkpointing and verify gradients are 7104 # equivalent, regardless of if inputs require grads or not. 7105 module_copy = deepcopy(module) 7106 7107 feat_combined = [] 7108 feat_combined_no_checkpoint = [] 7109 for r in range(num_inp): 7110 data_r = torch.empty(1, nz_inp) 7111 data_r.uniform_() 7112 data_r.requires_grad = input_requires_grad 7113 data_r_copy = data_r.clone() 7114 feat_r = checkpoint(module, data=data_r, use_reentrant=False) 7115 feat_combined.append(feat_r) 7116 feat_r_no_checkpoint = module_copy(data_r) 7117 feat_combined_no_checkpoint.append(feat_r_no_checkpoint) 7118 7119 # compute mean as a proxy for some joint reasoning 7120 mean_combined = torch.stack(feat_combined).mean() 7121 mean_combined.backward() 7122 mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean() 7123 mean_combined_no_checkpoint.backward() 7124 7125 for checkpoint_param, param in zip( 7126 module.parameters(), module_copy.parameters() 7127 ): 7128 self.assertEqual(checkpoint_param.grad, param.grad) 7129 7130 def test_checkpoint_valid_reset_on_error(self): 7131 a = torch.randn(2, 2, requires_grad=True) 7132 7133 with self.assertRaisesRegex( 7134 Exception, "torch.utils.checkpoint is incompatible" 7135 ): 7136 b = checkpoint(torch.exp, a, use_reentrant=True).sum() 7137 torch.autograd.grad(b, (a,)) 7138 7139 c = checkpoint(torch.exp, a, use_reentrant=True).sum() 7140 c.backward() 7141 7142 @parametrize("use_reentrant", [True, False]) 7143 def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant): 7144 class NoGradModule(torch.nn.Module): 7145 def __init__(self) -> None: 7146 super().__init__() 7147 self.linear = nn.Linear(2, 2, bias=False) 7148 self.lin2 = nn.Linear(2, 2, bias=False) 7149 7150 def forward(self, x): 7151 with torch.no_grad(): 7152 return self.lin2(self.linear(x)) 7153 7154 module = NoGradModule() 7155 7156 err_ctx = ( 7157 self.assertRaisesRegex( 7158 RuntimeError, "none of output has requires_grad=True" 7159 ) 7160 if use_reentrant 7161 else contextlib.nullcontext() 7162 ) 7163 7164 a = torch.randn(2, 2, requires_grad=True) 7165 for _ in range(3): 7166 with err_ctx: 7167 # out does not require grad 7168 out = checkpoint(module, a, use_reentrant=use_reentrant) 7169 # Make loss require grad, otherwise we would run into 7170 # "element 0 of tensors does not require grad and does not have a grad_fn" 7171 out += a 7172 out.sum().backward() 7173 7174 def test_checkpointing_without_reentrant_saved_object_identity(self): 7175 x_backward = None 7176 7177 class Test(torch.autograd.Function): 7178 @staticmethod 7179 def forward(ctx, x, y): 7180 ctx.save_for_backward(y) 7181 return x 7182 7183 @staticmethod 7184 def backward(ctx, x): 7185 nonlocal x_backward 7186 (x_backward,) = ctx.saved_tensors 7187 return x, None 7188 7189 a = torch.tensor(1.0, requires_grad=True) 7190 b = torch.tensor(1.0, 
requires_grad=False) 7191 7192 Test.apply(a, b).backward() 7193 self.assertIs(b, x_backward) 7194 7195 x_backward = None 7196 checkpoint(Test.apply, a, b, use_reentrant=False).backward() 7197 self.assertIs(b, x_backward) 7198 7199 def test_checkpointing_without_reentrant_correct_grad(self): 7200 """ 7201 Verifies that correct gradients are calculated for checkpoint 7202 without reentrant autograd, for both backward() and autograd.grad(). 7203 """ 7204 a = torch.randn(2, 2, requires_grad=True) 7205 7206 b = torch.exp(a).sum() 7207 b.backward() 7208 b_grad = a.grad 7209 7210 a.grad = None 7211 c = checkpoint(torch.exp, a, use_reentrant=False).sum() 7212 c.backward() 7213 c_grad = a.grad 7214 7215 a.grad = None 7216 d = checkpoint(torch.exp, a, use_reentrant=False).sum() 7217 (d_grad,) = torch.autograd.grad(d, (a,)) 7218 7219 self.assertEqual(b_grad, c_grad) 7220 self.assertEqual(b_grad, d_grad) 7221 7222 # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally 7223 @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115") 7224 def test_checkpointing_without_reentrant_dataparallel(self): 7225 """ 7226 Verifies gradient correctness when checkpoint without reentrant autograd 7227 is used in conjunction with DataParallel. 7228 """ 7229 7230 class LinearModule(torch.nn.Module): 7231 def __init__(self) -> None: 7232 super().__init__() 7233 self.linear = nn.Linear(2, 2, bias=False) 7234 7235 def forward(self, inp): 7236 return self.linear(inp) 7237 7238 a = torch.randn(2, 2, requires_grad=True) 7239 if torch.cuda.is_available(): 7240 a = a.cuda() 7241 7242 model = LinearModule() 7243 if torch.cuda.is_available(): 7244 model = model.cuda() 7245 7246 b = deepcopy(model)(a).sum() 7247 b.backward() 7248 b_grad = a.grad 7249 7250 a.grad = None 7251 7252 module = torch.nn.DataParallel(deepcopy(model)) 7253 c = checkpoint(module, a, use_reentrant=False).sum() 7254 c.backward() 7255 c_grad = a.grad 7256 7257 self.assertEqual(b_grad, c_grad) 7258 7259 def test_checkpointing_without_reentrant_parameter_used_in_an_out(self): 7260 """ 7261 Ensures that gradient hooks are only called once per tensor. 7262 """ 7263 w = torch.randn(10, 10, requires_grad=True) 7264 count = 0 7265 7266 def hook(grad): 7267 nonlocal count 7268 count += 1 7269 7270 w.register_hook(hook) 7271 x = torch.rand(10, 10, requires_grad=True) 7272 h = w * x # Using w outside the checkpoint 7273 out = checkpoint( 7274 lambda x: w * x, h, use_reentrant=False 7275 ) # Using w inside the checkpoint 7276 7277 out.sum().backward() 7278 # should only call hook once 7279 self.assertEqual(count, 1) 7280 7281 # https://github.com/pytorch/pytorch/issues/127115 7282 @xfailIfTorchDynamo 7283 def test_checkpointing_without_reentrant_arbitrary_input_output(self): 7284 """ 7285 Ensures checkpointing without reentrant autograd works with functions 7286 with arbitrary input/output structures. 
7287 """ 7288 7289 class MyModel(torch.nn.Module): 7290 def __init__(self) -> None: 7291 super().__init__() 7292 self.layer = torch.nn.Linear(5, 5, bias=False) 7293 7294 def forward(self, dict_input): 7295 tensor = dict_input["tensor"] 7296 return {"result": self.layer(tensor)} 7297 7298 model_no_checkpoint = MyModel() 7299 model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint) 7300 7301 inp = {"tensor": torch.randn(5, 5)} 7302 7303 out_no_checkpoint = model_no_checkpoint(inp)["result"].sum() 7304 7305 out_checkpoint = checkpoint( 7306 model_checkpoint_without_reentrant, inp, use_reentrant=False 7307 )["result"].sum() 7308 7309 self.assertEqual(out_checkpoint, out_no_checkpoint) 7310 7311 out_no_checkpoint.backward() 7312 out_checkpoint.backward() 7313 7314 for param, checkpoint_param in zip( 7315 model_no_checkpoint.parameters(), 7316 model_checkpoint_without_reentrant.parameters(), 7317 ): 7318 self.assertEqual(param.grad, checkpoint_param.grad) 7319 7320 def test_callback_adds_callback(self): 7321 called = [0] 7322 7323 def callback_final(): 7324 called[0] += 1 7325 7326 def callback_adds_callback(): 7327 called[0] += 1 7328 Variable._execution_engine.queue_callback(callback_final) 7329 7330 class MyFunc(Function): 7331 @staticmethod 7332 def forward(ctx, input): 7333 return input 7334 7335 @staticmethod 7336 @once_differentiable 7337 def backward(ctx, grad): 7338 Variable._execution_engine.queue_callback(callback_adds_callback) 7339 return grad 7340 7341 a = torch.rand((3, 3), requires_grad=True) 7342 b = MyFunc.apply(a) 7343 b.sum().backward() 7344 7345 self.assertEqual(called[0], 2) 7346 7347 @unittest.skipIf(not TEST_CUDA, "test requires CUDA") 7348 def test_callback_propagates_errors_from_device_thread(self): 7349 def callback(): 7350 raise RuntimeError("blah") 7351 7352 def hook_with_callback(*args): 7353 torch.autograd.Variable._execution_engine.queue_callback(callback) 7354 7355 t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda")) 7356 t.register_hook(hook_with_callback) 7357 output = t**2 7358 loss = output.sum() 7359 7360 with self.assertRaisesRegex(RuntimeError, "blah"): 7361 loss.backward() 7362 7363 def _test_reentrant_with_callbacks(self, install_callbacks_in_depths): 7364 counter = {} 7365 counter["inner"] = 0 7366 counter["outer"] = 0 7367 7368 def inc_inner_counter(): 7369 counter["inner"] += 1 7370 7371 def inc_outer_counter(): 7372 counter["outer"] += 1 7373 7374 class MyFunc(Function): 7375 @staticmethod 7376 def forward(ctx, input): 7377 return input 7378 7379 @staticmethod 7380 @once_differentiable 7381 def backward(ctx, input): 7382 if 1 in install_callbacks_in_depths: 7383 # Add a callback to execute. 7384 Variable._execution_engine.queue_callback(inc_inner_counter) 7385 7386 return input 7387 7388 class MyReentrantFunc(Function): 7389 @staticmethod 7390 def forward(ctx, input): 7391 return input 7392 7393 @staticmethod 7394 @once_differentiable 7395 def backward(ctx, input): 7396 if 0 in install_callbacks_in_depths: 7397 # Add a callback to execute. 7398 Variable._execution_engine.queue_callback(inc_outer_counter) 7399 # Reentrant backward call. 
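                # Any callback queued via queue_callback at this depth is still expected
                # to run exactly once per backward pass; the depth-specific tests below
                # assert the outer/inner counters accordingly.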
7400 tmp_inp = input.detach().requires_grad_() 7401 with torch.enable_grad(): 7402 tmp_out = (MyFunc.apply(tmp_inp)).sum() 7403 tmp_out.backward() 7404 return input 7405 7406 t1 = torch.rand((3, 3), requires_grad=True) 7407 t2 = MyReentrantFunc.apply(t1) 7408 t3 = t2.sum() 7409 torch.autograd.backward([t3]) 7410 7411 return counter 7412 7413 def test_reentrant_with_callbacks_depth_0(self): 7414 # Verify callback is called only once. 7415 ret = self._test_reentrant_with_callbacks([0]) 7416 self.assertEqual(1, ret["outer"]) 7417 self.assertEqual(0, ret["inner"]) 7418 7419 def test_reentrant_with_callbacks_depth_1(self): 7420 # Verify callback is called only once. 7421 ret = self._test_reentrant_with_callbacks([1]) 7422 self.assertEqual(0, ret["outer"]) 7423 self.assertEqual(1, ret["inner"]) 7424 7425 def test_reentrant_with_callbacks_both_depths(self): 7426 # Verify callback is called twice. 7427 ret = self._test_reentrant_with_callbacks([0, 1]) 7428 self.assertEqual(1, ret["outer"]) 7429 self.assertEqual(1, ret["inner"]) 7430 7431 def test_reentrant_with_leaf_variable_hook(self): 7432 handle = None 7433 param = torch.rand(10, requires_grad=True) 7434 7435 def add_gradient_penalty_to_grad(grad): 7436 handle.remove() 7437 old_param_grad = grad 7438 param.grad = None 7439 # Add some sort of gradient penalty by directly updating the gradients 7440 with torch.enable_grad(): 7441 g = grad.detach().requires_grad_() 7442 new_param = param.detach().requires_grad_() 7443 out = ((g * 2) + new_param).sum() 7444 out.backward() 7445 res = g.grad + grad 7446 param.grad = old_param_grad 7447 return res 7448 7449 handle = param.register_hook(add_gradient_penalty_to_grad) 7450 # Forward pass 7451 tmp = param * param 7452 loss = tmp.sum() 7453 # Compute the gradients 7454 loss.backward() 7455 7456 def test_reentrant_with_non_leaf_variable_hook(self): 7457 handle = None 7458 param = torch.rand(10, requires_grad=True) 7459 7460 def manual_increase_gradient(grad): 7461 handle.remove() 7462 # Add some sort of gradient penalty by directly updating the gradients 7463 with torch.enable_grad(): 7464 g = grad.detach().requires_grad_() 7465 out = ((g * 2) + 5).sum() 7466 out.backward() 7467 res = g.grad + grad 7468 return res 7469 7470 # Forward pass 7471 tmp = param * param 7472 handle = tmp.register_hook(manual_increase_gradient) 7473 loss = tmp.sum() 7474 # Compute the gradients 7475 loss.backward() 7476 self.assertEqual(param.grad, 6 * param) 7477 7478 def test_grad_fn_attr_bindings(self): 7479 # Check that the getter of each type returns what we want 7480 # See `gen_autograd_functions.py` for how the getters are generated 7481 # 7482 # This test is only meant to check if the codegen'd bindings work 7483 # Please help update this test if you update the names of any the fields we check! 7484 # 7485 a = torch.ones(1, requires_grad=True) 7486 b = torch.zeros(1, requires_grad=True) 7487 out1 = torch.stack([a, b], dim=0) 7488 out2 = (a * 2) * b 7489 # TODO: I don't think we have a backward saving a list of tensors 7490 # at the moment. It used to be stack, but for no reason... 
7491 # see discussion in #84993 7492 # self.assertEqual(out.grad_fn._saved_tensors, (a, b)) # TewnsorList -> Tuple[Tensor] 7493 self.assertEqual(out2.grad_fn._saved_self, a * 2) 7494 self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor) 7495 self.assertIsInstance( 7496 out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor 7497 ) 7498 self.assertEqual(out1.grad_fn._saved_dim, 0) # int64_t -> int 7499 self.assertIsInstance(out1.grad_fn._saved_dim, int) 7500 7501 out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 7502 7503 out2.sum().backward() 7504 with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7505 out2.grad_fn._saved_self 7506 # TODO: interestingly, this only happens if indexing into a list grad_fn._raw_saved_tensors[0], 7507 # not when using a saved tensor, see discussion in #84993 7508 # with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7509 # out2.grad_fn._raw_saved_self 7510 self.assertEqual(out1.grad_fn._saved_dim, 0) 7511 7512 a = torch.ones(2, 2, requires_grad=True) 7513 indices = torch.tensor([0, 1]) 7514 out = a[:, indices] 7515 self.assertEqual( 7516 out.grad_fn._saved_indices, (None, indices) 7517 ) # c10::List<std::optional<Tensor>> -> Tuple[Tensor?] 7518 self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor) 7519 self.assertIsInstance( 7520 out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor 7521 ) 7522 self.assertEqual( 7523 out.grad_fn._saved_self_sym_sizes, a.shape 7524 ) # SymIntArrayRef -> Tuple[SymInt] 7525 self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int) 7526 7527 out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x) 7528 with self.assertRaisesRegex(RuntimeError, "None is forbidden"): 7529 out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x) 7530 7531 out = a.mean() 7532 self.assertEqual( 7533 out.grad_fn._saved_self_sym_sizes, a.shape 7534 ) # IntArrayRef -> Tuple[int] 7535 7536 a = torch.ones(2, 2, requires_grad=True) 7537 out = a * a 7538 out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 7539 out.sum().backward() 7540 with self.assertRaisesRegex(RuntimeError, "after it has been freed"): 7541 out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 7542 7543 a = torch.ones(1, 1, 2, requires_grad=True) 7544 out = torch.nn.functional.interpolate(a, 4, mode="linear") 7545 self.assertEqual( 7546 out.grad_fn._saved_output_size, (4,) 7547 ) # std::optional<IntArrayRef> -> int[]? 7548 self.assertIsInstance(out.grad_fn._saved_output_size[0], int) 7549 self.assertEqual(out.grad_fn._saved_align_corners, False) # bool -> bool 7550 self.assertIsInstance(out.grad_fn._saved_align_corners, bool) 7551 if hasattr(out.grad_fn, "_saved_scale_factors"): 7552 self.assertIsNone( 7553 out.grad_fn._saved_scale_factors 7554 ) # std::optional<ArrayRef<double>> -> float[]? 7555 else: 7556 self.assertIsNone( 7557 out.grad_fn._saved_scales 7558 ) # std::optional<ArrayRef<double>> -> float[]? 7559 7560 a = torch.ones(1, 1, 3, 3, requires_grad=True) 7561 out = nn.Conv2d(1, 1, 3)(a) 7562 self.assertEqual( 7563 out.grad_fn._saved_bias_sym_sizes_opt, (1,) 7564 ) # std::optional<SymIntArrayRef> -> SymInt[]? 7565 out = nn.Conv2d(1, 1, 3, bias=False)(a) 7566 # TODO: This is BAD! 
we converted a std::nullopt into a (0,) 7567 self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,)) 7568 7569 a = torch.ones(1, 3, 3, requires_grad=True) 7570 out = torch.addbmm(a.squeeze(0), a, a) 7571 self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1) # int64_t 7572 self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3) # int64_t 7573 7574 a = torch.ones(1, 1, 3, 3, requires_grad=True) 7575 out = torch.nn.functional.unfold(a, 3) 7576 self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3) # SymInt 7577 self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3) # SymInt 7578 7579 a = torch.ones(1, 1, 2, requires_grad=True) 7580 out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear") 7581 self.assertEqual(out.grad_fn._saved_scales, 0.5) 7582 7583 a = torch.ones(2, 2, requires_grad=True) 7584 out = torch.pdist(a, p=1) 7585 self.assertEqual(out.grad_fn._saved_p, 1.0) # double -> float 7586 self.assertIsInstance(out.grad_fn._saved_p, float) 7587 7588 a = torch.ones(1, 1, 2, requires_grad=True) 7589 out = torch.logit(a, 1.0) 7590 self.assertEqual(out.grad_fn._saved_eps, 1.0) # c10:optional<double> -> float? 7591 self.assertIsInstance(out.grad_fn._saved_eps, float) 7592 out = torch.logit(a) 7593 self.assertIsNone(out.grad_fn._saved_eps) 7594 7595 if torch._C.has_lapack: 7596 a = torch.ones(1, 1, requires_grad=True) 7597 q, r = torch.linalg.qr(a, mode="reduced") 7598 self.assertEqual(q.grad_fn._saved_mode, "reduced") # std::string -> str 7599 7600 a = torch.tensor([1.0], requires_grad=True) 7601 out = torch.div(a, 2.0, rounding_mode="trunc") 7602 self.assertEqual( 7603 out.grad_fn._saved_rounding_mode, "trunc" 7604 ) # std::optional<std::string> -> str? 7605 out = torch.div(a, 2.0, rounding_mode=None) 7606 self.assertIsNone( 7607 out.grad_fn._saved_rounding_mode 7608 ) # std::optional<std::string> -> str? 7609 7610 x = torch.zeros(5, requires_grad=True) 7611 out = torch.threshold(x, threshold=(1 + 0j), value=(1 + 0j)) 7612 self.assertIsInstance( 7613 out.grad_fn._saved_threshold, complex 7614 ) # Scalar(complex double) -> complex 7615 cfloat = torch.tensor(1 + 0j, dtype=torch.complex64) 7616 out = torch.threshold(x, threshold=cfloat, value=(1 + 0j)) 7617 self.assertIsInstance( 7618 out.grad_fn._saved_threshold, complex 7619 ) # Scalar(complex float) -> complex 7620 out = torch.threshold(x, threshold=1.0, value=1.0) 7621 self.assertIsInstance( 7622 out.grad_fn._saved_threshold, float 7623 ) # Scalar(floating point) -> float 7624 out = torch.threshold(x, threshold=1, value=1) 7625 self.assertIsInstance( 7626 out.grad_fn._saved_threshold, int 7627 ) # Scalar(integral) -> int 7628 out = torch.threshold(x, threshold=False, value=False) 7629 self.assertIsInstance( 7630 out.grad_fn._saved_threshold, bool 7631 ) # Scalar(bool) -> bool 7632 7633 a = torch.ones(2, 2, requires_grad=True) 7634 out = a.as_strided((3,), (1,), 1) 7635 self.assertEqual( 7636 out.grad_fn._saved_storage_offset, 1 7637 ) # c10:optional<int64_t> -> int? 
7638 self.assertIsInstance(out.grad_fn._saved_storage_offset, int) 7639 out = a.as_strided((3,), (1,)) 7640 self.assertIsNone(out.grad_fn._saved_storage_offset) 7641 7642 a = torch.ones(2, requires_grad=True) 7643 out = torch.tanh(a) 7644 self.assertEqual(out, out.grad_fn._saved_result) # saved variable when output 7645 7646 a = torch.randn(3, 5, requires_grad=True) 7647 b = torch.tensor([1, 0, 4]) 7648 loss = nn.NLLLoss() 7649 out = loss(a, b) 7650 self.assertIsNone(out.grad_fn._saved_weight) 7651 loss = nn.NLLLoss(weight=torch.ones((5,))) 7652 out = loss(a, b) 7653 self.assertEqual( 7654 out.grad_fn._saved_weight, torch.ones((5,)) 7655 ) # c10:optional<Tensor> -> Tensor? 7656 7657 out.sum().backward() 7658 with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7659 out.grad_fn._saved_weight 7660 7661 num_tensors = 3 7662 input_tensors = [ 7663 torch.ones(2, 2, requires_grad=True) for _ in range(num_tensors) 7664 ] 7665 scalars = [ 7666 0.0 for _ in range(num_tensors) 7667 ] # ArrayRef<Scalar> -> Tuple[Scalar, ...] 7668 results = torch._foreach_maximum(input_tensors, scalars) 7669 for t in results: 7670 self.assertEqual(t.grad_fn._saved_scalars, scalars) 7671 7672 def test_cant_create_saved_tensors(self): 7673 with self.assertRaisesRegex( 7674 RuntimeError, 7675 "Trying to create a SavedTensor object from Python is forbidden", 7676 ): 7677 torch.autograd.SavedTensor() 7678 7679 def test_custom_function_saved_tensors(self): 7680 def getFn(save=True): 7681 class MyFn(Function): 7682 @staticmethod 7683 def forward(ctx, x): 7684 if save: 7685 ctx.save_for_backward(x, None) 7686 return x 7687 7688 @staticmethod 7689 def backward(ctx, g): 7690 return g 7691 7692 return MyFn 7693 7694 a = torch.randn(5, requires_grad=True) 7695 7696 y = getFn(True).apply(a) 7697 7698 self.assertEqual((a, None), y.grad_fn.saved_tensors) 7699 saved = y.grad_fn._raw_saved_tensors 7700 self.assertIsInstance(saved[0], torch._C._autograd.SavedTensor) 7701 # We can't tell the underlying tensor is None without unpacking it 7702 self.assertIsInstance(saved[1], torch._C._autograd.SavedTensor) 7703 7704 # We catch that error when the user calls register_hooks on it 7705 with self.assertRaisesRegex(RuntimeError, "None is forbidden"): 7706 saved[1].register_hooks(lambda x: x, lambda x: x) 7707 7708 with self.assertRaisesRegex(TypeError, "incompatible function arguments"): 7709 saved[0].register_hooks(lambda x: x) 7710 with self.assertRaisesRegex(TypeError, "incompatible function arguments"): 7711 saved[0].register_hooks(1, 1) 7712 saved[0].register_hooks(lambda x: x, lambda x: x) 7713 with self.assertRaisesRegex(RuntimeError, "already been set"): 7714 saved[0].register_hooks(lambda x: x, lambda x: x) 7715 y.sum().backward() 7716 7717 # Using a reference to the SavedTensor object after the 7718 # saved variables have been released can lead to undefined behavior 7719 del saved 7720 with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7721 y.grad_fn._raw_saved_tensors 7722 with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7723 y.grad_fn.saved_tensors 7724 7725 y = getFn(False).apply(a) 7726 self.assertEqual(y.grad_fn.saved_tensors, ()) 7727 self.assertEqual(y.grad_fn._raw_saved_tensors, ()) 7728 7729 def test_autograd_node_isinstance(self): 7730 # Node is a "virtual" base class of codegen'd nodes. 
This means that 7731 # isinstance and issubclass are overridden, but mro is unchanged 7732 Node = torch.autograd.graph.Node 7733 7734 a = torch.rand(3, 3, requires_grad=True) 7735 b = a.exp() 7736 7737 # Some nodes have codegened registrations to the torch._C._function module 7738 self.assertIsInstance(b.grad_fn, Node) 7739 self.assertTrue(issubclass(type(b.grad_fn), Node)) 7740 self.assertTrue(Node not in type(b.grad_fn).mro()) 7741 7742 # Other nodes have manual registrations to the torch._C._function module 7743 self.assertNotIsInstance(torch._C._functions.AccumulateGrad, Node) 7744 self.assertTrue(issubclass(torch._C._functions.AccumulateGrad, Node)) 7745 self.assertIsInstance(b.grad_fn.next_functions[0][0], Node) 7746 self.assertTrue(issubclass(torch._C._functions.DelayedError, Node)) 7747 7748 # Special cases 7749 self.assertNotIsInstance(None, Node) 7750 self.assertNotIsInstance(1, Node) 7751 self.assertNotIsInstance(Node, Node) 7752 self.assertTrue(issubclass(Node, Node)) 7753 7754 # Custom function case 7755 self.assertTrue(issubclass(torch.autograd.function.BackwardCFunction, Node)) 7756 7757 class Func(torch.autograd.Function): 7758 @staticmethod 7759 def forward(ctx, x): 7760 self.assertIsInstance(ctx, Node) 7761 return x 7762 7763 @staticmethod 7764 def backward(ctx, x): 7765 self.assertIsInstance(ctx, Node) 7766 return x 7767 7768 out = Func.apply(a) 7769 self.assertIsInstance(out.grad_fn, Node) 7770 self.assertTrue(issubclass(type(out.grad_fn), Node)) 7771 self.assertTrue(Node not in type(out.grad_fn).mro()) 7772 out.sum().backward() 7773 7774 def test_autograd_views_codegen(self): 7775 # This is not necessarily the absolute correct behavior, but this is the current 7776 # one. This test is here to make sure that any change to this behavior is detected 7777 # and not silent. The TODOs below mark the places with unexpected behavior. 7778 # Note that any change in these test will be BC-breaking and should be done carefully. 7779 7780 # This test checks the behavior of two codegen functions (view_as and unbind) 7781 # with respect to view tracking and inplace operation on the output. 7782 7783 def run_test(grad_mode, requires_grad, is_view, should_raise_tuple): 7784 def maybe_check_raise(fn, should_raise): 7785 self.assertTrue(should_raise is None or isinstance(should_raise, str)) 7786 if should_raise is not None: 7787 with self.assertRaisesRegex(RuntimeError, should_raise): 7788 fn() 7789 else: 7790 fn() 7791 7792 inp = torch.rand(2, requires_grad=requires_grad).clone() 7793 with torch.set_grad_enabled(grad_mode): 7794 out = inp.view_as(inp) 7795 # Are they differentiable views? 7796 self.assertTrue(out._is_view() == is_view) 7797 # Are inplace allowed? 7798 maybe_check_raise(lambda: out.add_(1), should_raise_tuple[0]) 7799 7800 inp = torch.rand(2, requires_grad=requires_grad).clone() 7801 with torch.set_grad_enabled(grad_mode): 7802 out = inp.unbind() 7803 # Are they differentiable views? 7804 self.assertTrue(out[0]._is_view() == is_view) 7805 self.assertTrue(out[1]._is_view() == is_view) 7806 # Are inplace allowed? 
7807 maybe_check_raise(lambda: out[0].add_(1), should_raise_tuple[1]) 7808 maybe_check_raise(lambda: out[1].add_(1), should_raise_tuple[2]) 7809 7810 # should_raise contains None if it should not raise 7811 # should_raise contains a string of the error if it should raise 7812 # The 3 elements are for view_as, first output of unbind and second output of unbind 7813 run_test( 7814 grad_mode=True, 7815 requires_grad=False, 7816 is_view=True, 7817 should_raise_tuple=(None, None, None), 7818 ) 7819 inp_change_err = ( 7820 "Output {} of UnbindBackward0 is a view and is being modified inplace." 7821 ) 7822 run_test( 7823 grad_mode=True, 7824 requires_grad=True, 7825 is_view=True, 7826 should_raise_tuple=( 7827 None, 7828 inp_change_err.format("0"), 7829 inp_change_err.format("1"), 7830 ), 7831 ) 7832 leaf_grad_err = ( 7833 "A view was created in no_grad mode and is being modified inplace" 7834 ) 7835 run_test( 7836 grad_mode=False, 7837 requires_grad=True, 7838 is_view=True, 7839 should_raise_tuple=(leaf_grad_err, leaf_grad_err, leaf_grad_err), 7840 ) 7841 run_test( 7842 grad_mode=False, 7843 requires_grad=False, 7844 is_view=True, 7845 should_raise_tuple=(None, None, None), 7846 ) 7847 7848 def test_inplace_not_requires_grad(self): 7849 class MyFn(torch.autograd.Function): 7850 @staticmethod 7851 def forward(ctx, inp): 7852 return inp.view_as(inp) 7853 7854 @staticmethod 7855 def backward(ctx, grad): 7856 return grad 7857 7858 # Original Tensor does not require grad 7859 a = torch.rand(1, 2) 7860 7861 # Tensor being written does require grad 7862 b = torch.rand(1, requires_grad=True) 7863 7864 # Take an invalid view on 'a' that should raise an error (warns during deprecation) 7865 view_a = MyFn.apply(a) 7866 7867 with self.assertRaisesRegex( 7868 RuntimeError, "This view was created inside a custom Function" 7869 ): 7870 view_a += b 7871 7872 # Extra test for copy_ that is a manual implementation and could be easily 7873 # forgotten when the codegen is updated (warns during deprecation) 7874 a = torch.rand(1, 2) 7875 b = torch.rand(1, requires_grad=True) 7876 view_a = MyFn.apply(a) 7877 7878 with self.assertRaisesRegex( 7879 RuntimeError, "This view was created inside a custom Function" 7880 ): 7881 view_a.copy_(b) 7882 7883 # Functions that should throw must properly throw 7884 a = torch.rand(1, 2) 7885 b = torch.rand(1, requires_grad=True) 7886 view_a = a.unbind()[0] 7887 with self.assertRaisesRegex( 7888 RuntimeError, 7889 "This view is the output of a function that returns " "multiple views.", 7890 ): 7891 view_a.copy_(b) 7892 7893 # Sanity check that views that should work still work 7894 a = torch.rand(1, 2) 7895 b = torch.rand(1, requires_grad=True) 7896 a.select(1, 0).copy_(b) 7897 7898 def _do_test_autograd_simple_views_python(self, dtype): 7899 # This is not necessarily the absolute correct behavior, but this is the current 7900 # one. This test is here to make sure that any change to this behavior is detected 7901 # and not silent. The TODOs below mark the places with unexpected behavior. 7902 # Note that any change in these test will be BC-breaking and should be done carefully. 7903 7904 # This checks the autograd.Function behavior when we return one or multiple outputs 7905 # while one of these is an input, a view of an input or of a temporary tensor. 
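# Illustrative sketch only (comments, not executed): the custom Functions
# defined below either return an input as-is, return a narrow() view of it,
# or return a view of a temporary clone. When the returned tensor is a view
# and is then modified inplace, autograd is expected to reject the inplace
# op, roughly:
#   out = IdOneOutput.apply(a, b, True)  # returns a view of `a`
#   out += 3  # RuntimeError: ... created inside a custom Function
# while the out-of-place variant (out = out + 3) is checked with gradcheck.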
7906 7907 # This indicator is used to track how many times the backward function was called 7908 bw_called = [0] 7909 # This indicator is used to check if the argument `ga` contains non-zero values 7910 ga_nz = [False] 7911 7912 class IdOneOutput(Function): 7913 @staticmethod 7914 def forward(ctx, a, b, make_view): 7915 if make_view: 7916 a = a.narrow(0, 0, 2) 7917 else: 7918 a = a.clone() 7919 return a 7920 7921 @staticmethod 7922 def backward(ctx, ga): 7923 bw_called[0] += 1 7924 return ga, None, None 7925 7926 class IdTwoOutput(Function): 7927 @staticmethod 7928 def forward(ctx, a, b, make_view): 7929 if make_view: 7930 a = a.narrow(0, 0, 2) 7931 else: 7932 a = a.clone() 7933 return a, a + b 7934 7935 @staticmethod 7936 def backward(ctx, ga, gab): 7937 bw_called[0] += 1 7938 if ga.eq(0).all(): 7939 ga_nz[0] = False 7940 else: 7941 ga_nz[0] = True 7942 return ga + gab, gab, None 7943 7944 class ViewOfTemp(Function): 7945 @staticmethod 7946 def forward(ctx, a, make_view): 7947 ctx.save_for_backward(a) 7948 if make_view: 7949 a = a.narrow(0, 0, 2) 7950 else: 7951 a = a.clone() 7952 b = a.clone() 7953 return b.select(0, 0) 7954 7955 @staticmethod 7956 def backward(ctx, grad): 7957 bw_called[0] += 1 7958 (a,) = ctx.saved_tensors 7959 res = torch.zeros_like(a) 7960 res.select(0, 0).copy_(grad) 7961 return res, None 7962 7963 fn_id_to_inplace_on_view_err_msg = { 7964 "one_output": ( 7965 "Output 0 of IdOneOutputBackward is a view and is being " 7966 "modified inplace. This view was created inside a custom Function" 7967 ), 7968 "two_output": ( 7969 "Output 0 of IdTwoOutputBackward is a view and is being modified inplace." 7970 " This view is the output of a function that returns multiple views." 7971 ), 7972 "view_of_temp": ( 7973 "Output 0 of ViewOfTempBackward is a view and is being " 7974 "modified inplace. 
This view was created inside a custom Function" 7975 ), 7976 } 7977 7978 for fn_id in ["one_output", "two_output", "view_of_temp"]: 7979 for inplace in [True, False]: 7980 for make_view in [True, False]: 7981 # Used for special casing the tests below 7982 output_is_a_view = make_view or fn_id == "view_of_temp" 7983 7984 def fn(a, b): 7985 # never modify a, b inplace for gracheck 7986 a = a.clone() 7987 b = b.clone() 7988 if fn_id == "two_output": 7989 tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view) 7990 if inplace: 7991 tmp1 += 3 7992 tmp2 += 3 7993 else: 7994 tmp1 = tmp1 + 3 7995 tmp2 = tmp2 + 3 7996 tmp = tmp1 * tmp2 7997 else: 7998 if fn_id == "one_output": 7999 tmp = IdOneOutput.apply(a, b, make_view) 8000 else: 8001 tmp = ViewOfTemp.apply(a + b, make_view) 8002 if inplace: 8003 tmp += 3 8004 else: 8005 tmp = tmp + 3 8006 8007 return tmp.sum() 8008 8009 a = torch.ones(2, dtype=dtype, requires_grad=True) 8010 b = torch.ones(2, dtype=dtype, requires_grad=True) 8011 8012 err_msg = fn_id_to_inplace_on_view_err_msg[fn_id] 8013 8014 if not inplace or not output_is_a_view: 8015 gradcheck(fn, (a, b), check_batched_grad=False) 8016 8017 # Was the custom backward called properly 8018 bw_called[0] = 0 8019 ga_nz[0] = True # For the case where the backward is called 8020 8021 if inplace and output_is_a_view: 8022 with self.assertRaisesRegex(RuntimeError, err_msg): 8023 fn(a, b) 8024 else: 8025 fn(a, b).abs().backward() 8026 8027 expected_called = 1 8028 expected_ga_nz = True 8029 8030 if output_is_a_view and inplace: 8031 expected_called = 0 8032 8033 self.assertTrue(bw_called[0] == expected_called) 8034 self.assertTrue(ga_nz[0] == expected_ga_nz) 8035 8036 def test_autograd_simple_views_python(self): 8037 self._do_test_autograd_simple_views_python(torch.double) 8038 self._do_test_autograd_simple_views_python(torch.cdouble) 8039 8040 def test_autograd_inplace_views_creation_meta(self): 8041 # Tests creation_meta properly handled for inplace views 8042 8043 class Func(torch.autograd.Function): 8044 @staticmethod 8045 def forward(ctx, x): 8046 return x.view_as(x) 8047 8048 @staticmethod 8049 def backward(ctx, x): 8050 return x 8051 8052 view_custom = Func.apply 8053 8054 def run_test( 8055 fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2 8056 ): 8057 # This test checks the behavior of inplace-view functions when 8058 # the views are created in grad mode or not 8059 base = torch.rand(2, 3, requires_grad=requires_grad).clone() 8060 # 1. Create a view with `grad_mode=grad_mode_view` 8061 with torch.set_grad_enabled(grad_mode_view): 8062 if fn_type == "multi_view": 8063 inp = base.unbind()[0] 8064 elif fn_type == "custom": 8065 inp = view_custom(base) 8066 else: 8067 inp = base.view_as(base) 8068 8069 # 2. Perform inplace view with `grad_mode=grad_mode_iview` 8070 with torch.set_grad_enabled(grad_mode_iview): 8071 if error1 is not None: 8072 with self.assertRaisesRegex(RuntimeError, error1): 8073 fn(inp) 8074 return 8075 else: 8076 # If error is None, check that runs without error 8077 fn(inp) 8078 # 3. 
Do inplace on the (new) view
8079 if error2 is not None:
8080 with self.assertRaisesRegex(RuntimeError, error2):
8081 inp.add_(1)
8082 else:
8083 # If error is None, check that runs without error
8084 inp.add_(1)
8085
8086 no_grad_err = "A view was created in no_grad mode"
8087 multi_view_err = "function that returns multiple views"
8088 custom_err = "view was created inside a custom Function"
8089
8090 def run_tests(fn):
8091 for fn_type in ("normal", "multi_view", "custom"):
8092 for grad_mode_view in (True, False):
8093 for grad_mode_iview in (True, False):
8094 for requires_grad in (True, False):
8095 error1 = None # expected error when we do inplace_view on original view
8096 error2 = None # expected error when we do inplace on the resulting view
8097
8098 if requires_grad:
8099 if not grad_mode_view and grad_mode_iview:
8100 error1 = no_grad_err
8101 if not grad_mode_view and not grad_mode_iview:
8102 error2 = no_grad_err
8103
8104 if fn_type == "multi_view":
8105 if grad_mode_view and grad_mode_iview:
8106 error1 = multi_view_err
8107 if grad_mode_view and not grad_mode_iview:
8108 error2 = multi_view_err
8109
8110 if fn_type == "custom":
8111 if grad_mode_view and grad_mode_iview:
8112 error1 = custom_err
8113 if grad_mode_view and not grad_mode_iview:
8114 error2 = custom_err
8115
8116 run_test(
8117 fn,
8118 fn_type,
8119 grad_mode_view,
8120 grad_mode_iview,
8121 requires_grad,
8122 error1,
8123 error2,
8124 )
8125
8126 # This list was created by logging gen_inplace_or_view_type.py
8127 # detach_ is excluded for this test because it cannot be applied to
8128 # views and thus does not return a view
8129 run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
8130 run_tests(lambda v: v.transpose_(0, 0))
8131 run_tests(lambda v: v.t_())
8132 run_tests(lambda v: v.squeeze_(0))
8133 run_tests(lambda v: v.unsqueeze_(0))
8134 run_tests(lambda v: v.swapdims_(0, 0))
8135 run_tests(lambda v: v.swapaxes_(0, 0))
8136
8137 def test_autograd_print_tensor(self):
8138 a = torch.ones(1, requires_grad=True)
8139 a_clone = a.clone()
8140 self.assertEqual(repr(a), "tensor([1.], requires_grad=True)")
8141 self.assertEqual(repr(a_clone), "tensor([1.], grad_fn=<CloneBackward0>)")
8142
8143 with torch.no_grad():
8144 b = a[:]
8145 b *= 2
8146
8147 # Special handling for printing a view created in no_grad mode and
8148 # modified in-place in no_grad mode.
8149 self.assertEqual(repr(b), "tensor([2.], grad_fn=<Invalid>)")
8150
8151 class Func(torch.autograd.Function):
8152 @staticmethod
8153 def forward(ctx, x):
8154 return x
8155
8156 @staticmethod
8157 def backward(ctx, x):
8158 return x
8159
8160 c = Func.apply(a)
8161 self.assertEqual(repr(c), "tensor([2.], grad_fn=<FuncBackward>)")
8162
8163 def test_autograd_inplace_view_of_view(self):
8164 x = torch.zeros(2)
8165 with torch.no_grad():
8166 y = x.view(2)
8167 y.requires_grad_(True)
8168 z = y.view(2)
8169 with self.assertRaisesRegex(
8170 RuntimeError, "a view of a view .* is being .* inside the no_grad block"
8171 ):
8172 z /= 2
8173
8174 x = torch.zeros(2)
8175 with torch.inference_mode():
8176 y = x.view(2)
8177 y.requires_grad_(True)
8178 z = y.view(2)
8179 with self.assertRaisesRegex(
8180 RuntimeError, "a view of a view .* is being .* inside the inference_mode"
8181 ):
8182 z /= 2
8183
8184 # TODO: This is not the correct behavior -
8185 # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
8186 def test_autograd_inplace_views_cross_dtype(self):
8187 # This test is here to make sure that any change to this behavior is detected
8188 # and not silent. The TODOs below mark the places with unexpected behavior.
8189 a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
8190 a = a_orig.clone()
8191 b = torch.view_as_real(a)
8192 b = b.transpose(0, 1)
8193 b += 1
8194 b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
8195 non_inplace_grad = a_orig.grad
8196
8197 a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
8198 a = a_orig.clone()
8199 b = torch.view_as_real(a)
8200 b.transpose_(0, 1)
8201 b += 1
8202 b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
8203 inplace_grad = a_orig.grad
8204
8205 # TODO: this is a bug!
8206 # once this is fixed, it should have the transpose removed:
8207 # self.assertEqual(non_inplace_grad, inplace_grad)
8208 self.assertEqual(non_inplace_grad.T, inplace_grad)
8209
8210 def test_autograd_multiple_views_python(self):
8211 # This is not necessarily the absolute correct behavior, but this is the current
8212 # one. This test is here to make sure that any change to this behavior is detected
8213 # and not silent. The TODOs below mark the places with unexpected behavior.
8214 # Note that any change in these tests will be BC-breaking and should be done carefully.
8215
8216 # This checks that multiple views in the forward are properly traced and how they
8217 # behave with respect to inplace operations.
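# Rough sketch of what is exercised below (comments only): ComplexView.forward
# creates two views of its input (a narrow() followed by a select()) and
# returns the last one. Backward is still expected to run exactly once, and
# an inplace op on the returned view is expected to raise, e.g.:
#   out = ComplexView.apply(a.clone(), idx)
#   out += 1  # RuntimeError: Output 0 of ComplexViewBackward is a view ...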
8218 8219 # This indicator is used to track how many times the backward function was called 8220 bw_called = [0] 8221 8222 class ComplexView(Function): 8223 @staticmethod 8224 def forward(ctx, a, idx): 8225 res = a.narrow(0, idx, 1) 8226 res = a.select(0, idx) 8227 ctx.save_for_backward(a) 8228 ctx.idx = idx 8229 return res 8230 8231 @staticmethod 8232 def backward(ctx, grad): 8233 bw_called[0] += 1 8234 (a,) = ctx.saved_tensors 8235 res = torch.zeros_like(a) 8236 res.select(0, ctx.idx).copy_(grad) 8237 return res, None 8238 8239 a = torch.ones(2, requires_grad=True) 8240 idx = 1 8241 8242 bw_called[0] = 0 8243 out = ComplexView.apply(a.clone(), idx) 8244 out.sum().backward() 8245 self.assertTrue(bw_called[0] == 1) 8246 8247 out = ComplexView.apply(a.clone(), idx) 8248 with self.assertRaisesRegex( 8249 RuntimeError, 8250 "Output 0 of ComplexViewBackward is a view and is being modified inplace", 8251 ): 8252 out += 1 8253 8254 def test_autograd_python_custom_function_inplace(self): 8255 # This is not necessarily the absolute correct behavior, but this is the current 8256 # one. This test is here to make sure that any change to this behavior is detected 8257 # and not silent. The TODOs below mark the places with unexpected behavior. 8258 # Note that any change in these test will be BC-breaking and should be done carefully. 8259 8260 # This test checks custom autograd.Function that perform inplace operations 8261 8262 bw_called = [0] 8263 8264 # I) Single output 8265 class MyAdder(Function): 8266 @staticmethod 8267 def forward(ctx, a, b): 8268 a.add_(b) 8269 ctx.mark_dirty(a) 8270 return a 8271 8272 @staticmethod 8273 def backward(ctx, grad): 8274 bw_called[0] += 1 8275 return grad, grad 8276 8277 a = torch.ones(2, requires_grad=True) 8278 b = torch.ones(2, requires_grad=True) 8279 8280 # No extra inplace 8281 c = MyAdder.apply(a.clone(), b) 8282 c.sum().backward() 8283 self.assertTrue(bw_called[0] == 1) 8284 8285 # With extra inplace on the output 8286 bw_called[0] = 0 8287 c = MyAdder.apply(a.clone(), b) 8288 c += 2 8289 c.sum().backward() 8290 self.assertTrue(bw_called[0] == 1) 8291 8292 # The input is a view 8293 bw_called[0] = 0 8294 c = MyAdder.apply(a.clone().view_as(a), b) 8295 c.sum().backward() 8296 self.assertTrue(bw_called[0] == 1) 8297 8298 # Should not give non-inputs to mark_dirty 8299 class MyAdderBad(Function): 8300 @staticmethod 8301 def forward(ctx, a, b): 8302 c = 3 * a 8303 c.add_(b) 8304 ctx.mark_dirty(c) 8305 return c 8306 8307 @staticmethod 8308 def backward(ctx, grad): 8309 bw_called[0] += 1 8310 grad = 3 * grad 8311 return grad, grad 8312 8313 a = torch.ones(2, requires_grad=True) 8314 b = torch.ones(2, requires_grad=True) 8315 8316 with warnings.catch_warnings(record=True) as w: 8317 MyAdderBad.apply(a.clone(), b) 8318 self.assertEqual(len(w), 1) 8319 8320 # II) Multiple outputs 8321 class MyBadAdder(Function): 8322 @staticmethod 8323 def forward(ctx, a, b): 8324 a.add_(b) 8325 ctx.mark_dirty(a) 8326 return a, a + b 8327 8328 @staticmethod 8329 def backward(ctx, ga, gab): 8330 bw_called[0] += 1 8331 return ga + gab, ga + gab 8332 8333 # No extra inplace 8334 bw_called[0] = 0 8335 c, d = MyBadAdder.apply(a.clone(), b) 8336 (c * d).sum().backward() 8337 self.assertTrue(bw_called[0] == 1) 8338 8339 # With extra inplace on the output 8340 bw_called[0] = 0 8341 c, d = MyBadAdder.apply(a.clone(), b) 8342 c += 2 8343 (c * d).sum().backward() 8344 self.assertTrue(bw_called[0] == 1) 8345 8346 # The input is a view 8347 inplace_on_view_err = ( 8348 "your Function modifies 
inplace an input that is a view of another Tensor" 8349 ) 8350 with self.assertRaisesRegex(RuntimeError, inplace_on_view_err): 8351 c, d = MyBadAdder.apply(a.clone().view_as(a), b) 8352 8353 # III) Inplace + other op 8354 class MyOutPlaceAdder(Function): 8355 @staticmethod 8356 def forward(ctx, a, b): 8357 a.add_(b) 8358 ctx.mark_dirty(a) 8359 return a.clone(), a + b 8360 8361 @staticmethod 8362 def backward(ctx, ga, gab): 8363 bw_called[0] += 1 8364 return ga + gab, ga + 2 * gab 8365 8366 # We don't reuse the input 8367 def fn(a, b): 8368 orig_a = a.clone().view_as(a) 8369 c, d = MyOutPlaceAdder.apply(orig_a, b) 8370 return (c * d).sum() 8371 8372 bad_mark_dirty_err = "Some elements marked as dirty during the forward method were not returned as output." 8373 with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err): 8374 fn(a, b) 8375 8376 def test_custom_function_mark_dirty_not_differentiable(self): 8377 def get_custom_fn(jvp_err): 8378 class InplaceMul(torch.autograd.Function): 8379 @staticmethod 8380 def forward(ctx, x): 8381 result = x.mul_(2) 8382 ctx.mark_dirty(result) 8383 return result 8384 8385 @staticmethod 8386 def backward(ctx, grad_output): 8387 pass 8388 8389 @staticmethod 8390 def jvp(ctx, x_t): 8391 if jvp_err: 8392 return x_t 8393 else: 8394 return x_t.mul_(2) 8395 8396 return InplaceMul 8397 8398 for requires_grad, jvp_err in product([True, False], repeat=2): 8399 InplaceMul = get_custom_fn(jvp_err) 8400 # Make sure that tensor is always returned as-is if marked dirty 8401 z = torch.tensor(1.0, requires_grad=requires_grad) 8402 x = z.clone() 8403 y = InplaceMul.apply(x) 8404 self.assertTrue(x is y) 8405 self.assertEqual(x, z * 2) 8406 8407 # jvp must properly modify the input grad if mark_dirty is set 8408 with fwAD.dual_level(): 8409 x_tangent = torch.ones_like(x) 8410 x_dual = fwAD.make_dual(x, x_tangent) 8411 8412 if jvp_err: 8413 bad_mark_dirty_err = ( 8414 "jvp function must modify the corresponding gradient inplace" 8415 ) 8416 with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err): 8417 InplaceMul.apply(x_dual) 8418 else: 8419 out_dual = InplaceMul.apply(x_dual) 8420 _, out_tangent = fwAD.unpack_dual(out_dual) 8421 self.assertTrue(out_dual is x_dual) 8422 self.assertTrue(out_tangent is x_tangent) 8423 8424 def test_named_tensor_for_complex_views(self): 8425 names = ["batch", "height", "width", "complex"] 8426 z = torch.ones((2, 1, 2, 2), requires_grad=True) 8427 z_named = z.refine_names(*names) 8428 z_complex = torch.view_as_complex(z_named.rename(None)).refine_names( 8429 *names[:-1] 8430 ) 8431 z_complex.sum().abs().backward() 8432 expected = torch.ones_like(z_complex).rename(None) 8433 abs_1_1j = abs(1 + 1j) 8434 expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2)) 8435 self.assertEqual(z.grad, torch.view_as_real(expected)) 8436 8437 def test_custom_function_return_view_in_nograd(self): 8438 class Alias(Function): 8439 @staticmethod 8440 def forward(ctx, x): 8441 return x[:] 8442 8443 @staticmethod 8444 def backward(ctx, gx): 8445 return gx 8446 8447 inp = torch.rand(2, requires_grad=True) 8448 8449 with torch.no_grad(): 8450 output = Alias.apply(inp) 8451 8452 with torch.no_grad(): 8453 expected_output = inp[:] 8454 8455 # Calling the custom function should operate as if we called an equivalent op 8456 self.assertEqual(output.requires_grad, expected_output.requires_grad) 8457 8458 # Check that in-place modification on view throws 8459 leaf_grad_err = ( 8460 "A view was created in no_grad mode and is being modified inplace" 8461 ) 8462 with 
self.assertRaisesRegex(RuntimeError, leaf_grad_err): 8463 output.zero_() 8464 8465 def test_custom_function_preserve_torch_function_when_return_as_is(self): 8466 class Custom(torch.Tensor): 8467 def __init__(self, data): 8468 super().__init__() 8469 self._data = data 8470 8471 @classmethod 8472 def __torch_function__(cls, func, types, args=(), kwargs=None): 8473 kwargs = {} if kwargs is None else kwargs 8474 args = tuple(a._data if isinstance(a, cls) else a for a in args) 8475 out = func(*args, **kwargs) 8476 if isinstance(out, torch.Tensor): 8477 out = cls(out) 8478 return out 8479 8480 class Fn(torch.autograd.Function): 8481 @staticmethod 8482 def forward(ctx, input): 8483 return input 8484 8485 @staticmethod 8486 def backward(ctx): 8487 pass 8488 8489 x = Custom(torch.randn(2, 3)) 8490 y = Fn.apply(x) 8491 self.assertTrue(isinstance(y, Custom)) 8492 8493 def test_grad_mode_restored_reentrant(self): 8494 class MyFunction(Function): 8495 @staticmethod 8496 def forward(ctx, inp): 8497 return inp.clone() 8498 8499 @staticmethod 8500 def backward(ctx, go): 8501 original = torch._C.is_grad_enabled() 8502 with torch.enable_grad(): 8503 self.assertTrue(torch._C.is_grad_enabled()) 8504 foo = torch.rand(go.size(), requires_grad=True) 8505 (grad,) = torch.autograd.grad(foo**3, foo, grad_outputs=go) 8506 self.assertTrue(torch._C.is_grad_enabled()) 8507 self.assertTrue(torch._C.is_grad_enabled() == original) 8508 return grad 8509 8510 inp = torch.rand(3, requires_grad=True) 8511 8512 # Case where original==False 8513 MyFunction.apply(inp).sum().backward() 8514 # Case where original==True 8515 MyFunction.apply(inp).sum().backward(create_graph=True) 8516 8517 def test_power_function(self): 8518 a = torch.tensor([0.0, 0.0, 0.0]) 8519 b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True) 8520 c = torch.sum(a**b) 8521 c.backward() 8522 self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0])) 8523 8524 s = 0 8525 b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True) 8526 c = torch.sum(s**b) 8527 c.backward() 8528 self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0])) 8529 8530 def test_custom_function_error(self): 8531 class BadFw(Function): 8532 @staticmethod 8533 def backward(ctx, foo): 8534 return foo 8535 8536 class BadBw(Function): 8537 @staticmethod 8538 def forward(ctx, foo): 8539 return foo.clone() 8540 8541 class BadBw2(Function): 8542 @staticmethod 8543 def forward(ctx, foo): 8544 return foo.clone() 8545 8546 @staticmethod 8547 def backward(ctx, foo): 8548 return foo 8549 8550 @staticmethod 8551 def vjp(ctx, foo): 8552 return foo 8553 8554 class BadJvp(Function): 8555 @staticmethod 8556 def forward(ctx, foo): 8557 return foo.clone() 8558 8559 inp = torch.rand(1, requires_grad=True) 8560 with self.assertRaisesRegex(NotImplementedError, "must implement the forward"): 8561 BadFw.apply(inp) 8562 8563 with self.assertRaisesRegex(RuntimeError, "must implement either the backward"): 8564 BadBw.apply(inp).sum().backward() 8565 8566 with self.assertRaisesRegex( 8567 RuntimeError, "Implementing both 'backward' and 'vjp'" 8568 ): 8569 BadBw2.apply(inp).sum().backward() 8570 8571 with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"): 8572 with fwAD.dual_level(): 8573 d = fwAD.make_dual(inp, torch.rand_like(inp)) 8574 res = BadJvp.apply(d) 8575 8576 def test_custom_function_forward_mode_view_checks(self): 8577 flag_to_error = { 8578 "ok": None, 8579 "not_a_view": "jvp is not returning a view", 8580 "not_a_view_of_inp": "jvp is not returning a view of the given", 8581 
"not_a_view_of_inp_base": "jvp is not returning a view of the same base", 8582 } 8583 8584 class ViewFn(Function): 8585 @staticmethod 8586 def forward(ctx, foo, flag): 8587 ctx.flag = flag 8588 ctx.size = foo.size() 8589 return foo.narrow(0, 0, 2) 8590 8591 @staticmethod 8592 def vjp(ctx, gO): 8593 gI = gO.new_zeros(ctx.size) 8594 gI.narrow(0, 0, 2).copy_(gO) 8595 return gI, None 8596 8597 @staticmethod 8598 def jvp(ctx, gI, _): 8599 res = gI.narrow(0, 0, 2) 8600 if ctx.flag != "ok": 8601 # Break the view in the gradients! 8602 res = res.clone() 8603 if ctx.flag in ["not_a_view_of_inp", "not_a_view_of_inp_base"]: 8604 # Result should be a view, just of the wrong thing 8605 res = res.view_as(res) 8606 return res 8607 8608 inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True) 8609 8610 for flag, msg in flag_to_error.items(): 8611 8612 def test_fn(inp): 8613 if flag == "not_a_view_of_inp_base": 8614 inp = inp.view_as(inp) 8615 return ViewFn.apply(inp, flag) 8616 8617 if msg is None: 8618 gradcheck(test_fn, inp, check_forward_ad=True) 8619 else: 8620 with self.assertRaisesRegex(RuntimeError, msg): 8621 gradcheck(test_fn, inp, check_forward_ad=True) 8622 8623 def test_custom_function_forward_mode_inplace_checks(self): 8624 class InplaceFn(Function): 8625 @staticmethod 8626 def forward(ctx, foo, flag): 8627 ctx.mark_dirty(foo) 8628 ctx.flag = flag 8629 foo.mul_(2) 8630 return foo 8631 8632 @staticmethod 8633 def vjp(ctx, gO): 8634 return 2 * gO, None 8635 8636 @staticmethod 8637 def jvp(ctx, gI, _): 8638 if ctx.flag: 8639 # Don't do the change inplace 8640 return 2 * gI 8641 else: 8642 gI.mul_(2) 8643 return gI 8644 8645 inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True) 8646 8647 def test_fn(inp, flag): 8648 inp = inp.clone() 8649 return InplaceFn.apply(inp, flag) 8650 8651 gradcheck(test_fn, (inp, False), check_forward_ad=True) 8652 8653 with self.assertRaisesRegex( 8654 RuntimeError, 8655 "inplace custom Function is not modifying the forward mode gradients inplace", 8656 ): 8657 gradcheck(test_fn, (inp, True), check_forward_ad=True) 8658 8659 def test_custom_function_forward_mode_wrong_formula(self): 8660 class UserFn(Function): 8661 @staticmethod 8662 def forward(ctx, foo, should_fail): 8663 ctx.should_fail = should_fail 8664 return foo * 2 8665 8666 @staticmethod 8667 def vjp(ctx, gO): 8668 return 2 * gO, None 8669 8670 @staticmethod 8671 def jvp(ctx, gI, _): 8672 if ctx.should_fail: 8673 # Wrong gradient formula 8674 return 3 * gI 8675 else: 8676 return 2 * gI 8677 8678 inp = torch.rand(10, dtype=torch.double, requires_grad=True) 8679 gradcheck(UserFn.apply, (inp, False), check_forward_ad=True) 8680 8681 with self.assertRaisesRegex( 8682 RuntimeError, "Jacobian computed with forward mode mismatch for output 0" 8683 ): 8684 gradcheck(UserFn.apply, (inp, True), check_forward_ad=True) 8685 8686 def test_custom_function_forward_mode_non_tensor_before_tensor_args(self): 8687 class MyFn(torch.autograd.Function): 8688 @staticmethod 8689 def forward(ctx, nt, x, nt2, y): 8690 return x * 2 + y * 3 8691 8692 @staticmethod 8693 def jvp(ctx, nt, x_t, nt2, y_t): 8694 self.assertIsNone(nt) 8695 self.assertIsNone(nt2) 8696 return x_t * 2 + y_t * 3 8697 8698 x = torch.tensor(1.0, dtype=torch.double) 8699 t = torch.tensor(1.0, dtype=torch.double) 8700 y = torch.tensor(1.0, dtype=torch.double) 8701 8702 with fwAD.dual_level(): 8703 dual_x = fwAD.make_dual(x, t) 8704 MyFn.apply(1, dual_x, 1, y) 8705 8706 gradcheck( 8707 MyFn.apply, 8708 (1, x.requires_grad_(True), 1, 
y.requires_grad_(True)), 8709 check_forward_ad=True, 8710 check_backward_ad=False, 8711 check_batched_grad=False, 8712 ) 8713 8714 def test_custom_function_forward_mode_forward_is_no_op(self): 8715 error_regex = ( 8716 "A custom Function's forward is returning a view \\(or an input as-is\\)" 8717 ) 8718 8719 return_lambdas = { 8720 # If we return an input as-is in forward, that is treated 8721 # as if self.view_as(self) is performed. If jvp returns x.view_as(x), 8722 # this is OK. 8723 "view_as": lambda x: x.view_as(x), 8724 # Expect this to raise an error 8725 "self": lambda x: x, 8726 # Expect this to raise the same error 8727 "mul_by_2": lambda x: x * 2, 8728 } 8729 8730 for k, fn in return_lambdas.items(): 8731 8732 class MyFn(torch.autograd.Function): 8733 @staticmethod 8734 def forward(ctx, x, y): 8735 return x + y, x 8736 8737 @staticmethod 8738 def vjp(ctx, gO1, gO2): 8739 return gO1 + gO2, gO1 8740 8741 @staticmethod 8742 def jvp(ctx, x_t, y_t): 8743 return x_t + y_t, fn(x_t) 8744 8745 a = torch.tensor(1.0, dtype=torch.double, requires_grad=True) 8746 t = torch.tensor(1.0, dtype=torch.double) 8747 b = torch.tensor(1.0, dtype=torch.double, requires_grad=True) 8748 8749 c = torch.tensor(1.0, dtype=torch.double) 8750 t2 = torch.tensor(1.0, dtype=torch.double) 8751 d = torch.tensor(1.0, dtype=torch.double) 8752 8753 with fwAD.dual_level(): 8754 a_dual = fwAD.make_dual(a, t) 8755 c_dual = fwAD.make_dual(c, t2) 8756 8757 if k == "view_as": 8758 _, out2 = MyFn.apply(a_dual, b) 8759 self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t) 8760 8761 _, out2 = MyFn.apply(c_dual, d) 8762 self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t2) 8763 else: 8764 with self.assertRaisesRegex(RuntimeError, error_regex): 8765 MyFn.apply(a_dual, b) 8766 8767 with self.assertRaisesRegex(RuntimeError, error_regex): 8768 MyFn.apply(c_dual, d) 8769 8770 if k == "view_as": 8771 gradcheck(MyFn.apply, (a, c), check_forward_ad=True) 8772 else: 8773 with self.assertRaisesRegex(RuntimeError, error_regex): 8774 gradcheck(MyFn.apply, (a, c), check_forward_ad=True) 8775 8776 def test_custom_function_save_for_forward(self): 8777 class Func(torch.autograd.Function): 8778 @staticmethod 8779 def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int): 8780 ctx.save_for_backward(x, y) 8781 ctx.save_for_forward(x, y) 8782 ctx.z = z 8783 ctx.prod = x * y 8784 return z * ctx.prod 8785 8786 @staticmethod 8787 def jvp(ctx, x_t, y_t, _): 8788 x_p, y_p = ctx.saved_tensors 8789 z = ctx.z 8790 return z * (y_p * x_t + x_p * y_t) 8791 8792 @staticmethod 8793 def vjp(ctx, grad_out): 8794 x, y = ctx.saved_tensors 8795 z = ctx.z 8796 return z * grad_out * y, z * grad_out * x, None 8797 8798 a = torch.tensor(1.0, requires_grad=True, dtype=torch.double) 8799 t = torch.tensor(1.0, dtype=torch.double) 8800 b = torch.tensor(2.0, requires_grad=True, dtype=torch.double) 8801 c = 4 8802 8803 with fwAD.dual_level(): 8804 a_dual = fwAD.make_dual(a, t) 8805 out = Func.apply(a_dual, b, c) 8806 out.backward() 8807 8808 gradcheck(Func.apply, (a, b, c), check_forward_ad=True) 8809 8810 # When saved for backward, but not saved for forward 8811 class Func(torch.autograd.Function): 8812 @staticmethod 8813 def forward(ctx, x: torch.Tensor): 8814 ctx.save_for_backward(x) 8815 return x.clone() 8816 8817 @staticmethod 8818 def jvp(ctx, x_t): 8819 self.assertEqual(len(ctx.saved_tensors), 0) 8820 return x_t 8821 8822 @staticmethod 8823 def vjp(ctx, grad_out): 8824 (x,) = ctx.saved_tensors 8825 self.assertEqual(len(ctx.saved_tensors), 1) 8826 return 
grad_out 8827 8828 with fwAD.dual_level(): 8829 a_dual = fwAD.make_dual(a, t) 8830 out = Func.apply(a_dual) 8831 out.backward() 8832 8833 gradcheck(Func.apply, (a,), check_forward_ad=True) 8834 8835 @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py") 8836 def test_custom_function_forward_mode_non_differentiable(self): 8837 # returns differentiable type, marked non-differentiable 8838 class Func(torch.autograd.Function): 8839 @staticmethod 8840 def forward(ctx, x, y): 8841 out = y.clone() 8842 ctx.mark_non_differentiable(out) 8843 return x.clone(), out 8844 8845 @staticmethod 8846 def jvp(ctx, x_tangent, y_tangent): 8847 return x_tangent, None 8848 8849 x = torch.tensor(2.0) 8850 x_tangent = torch.tensor(1.0) 8851 y = torch.tensor(3.0) 8852 8853 with fwAD.dual_level(): 8854 x_dual = fwAD.make_dual(x, x_tangent) 8855 _, out2_dual = Func.apply(x_dual, y) 8856 self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None) 8857 8858 y = torch.tensor(3) 8859 8860 # returns non-differentiable type, NOT marked non-differentiable 8861 class Func(torch.autograd.Function): 8862 @staticmethod 8863 def forward(ctx, x, y): 8864 return x.clone(), y.clone() 8865 8866 @staticmethod 8867 def jvp(ctx, x_tangent, y_tangent): 8868 self.assertIsNone(y_tangent) 8869 return x_tangent, None 8870 8871 with fwAD.dual_level(): 8872 x_dual = fwAD.make_dual(x, x_tangent) 8873 _, out2_dual = Func.apply(x_dual, y) 8874 self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None) 8875 8876 class FuncWrong(torch.autograd.Function): 8877 @staticmethod 8878 def forward(ctx, x, y): 8879 out = y.clone() 8880 ctx.mark_non_differentiable(out) 8881 return x.clone(), out 8882 8883 @staticmethod 8884 def jvp(ctx, x_tangent, y_tangent): 8885 return x_tangent, x_tangent.clone() 8886 8887 with fwAD.dual_level(): 8888 x_dual = fwAD.make_dual(x, x_tangent) 8889 with self.assertRaisesRegex( 8890 RuntimeError, "You should return None at that position instead" 8891 ): 8892 FuncWrong.apply(x_dual, y) 8893 8894 # returns non-tensor 8895 class Func(torch.autograd.Function): 8896 @staticmethod 8897 def forward(ctx, x): 8898 return x.clone(), object(), x.clone() 8899 8900 @staticmethod 8901 def jvp(ctx, x_tangent): 8902 return x_tangent, None, x_tangent 8903 8904 with fwAD.dual_level(): 8905 x_dual = fwAD.make_dual(x, x_tangent) 8906 out_dual, _, out2_dual = Func.apply(x_dual) 8907 self.assertEqual(fwAD.unpack_dual(out_dual).tangent, x_tangent) 8908 self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, x_tangent) 8909 8910 def test_custom_function_local_inplace(self): 8911 class MyFn(torch.autograd.Function): 8912 @staticmethod 8913 def forward(ctx, inp, inplace): 8914 view = inp.clone()[:3] 8915 if inplace: 8916 view += 2 8917 return view 8918 8919 @staticmethod 8920 def backward(ctx, grad): 8921 return grad, None 8922 8923 base = torch.rand(10, requires_grad=True) 8924 8925 foo = MyFn.apply(base, False) 8926 self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward") 8927 8928 foo = MyFn.apply(base, True) 8929 self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward") 8930 8931 def test_integer_outputs(self): 8932 inp = torch.rand(4, requires_grad=True) 8933 8934 out = inp.argmax() 8935 self.assertFalse(out.dtype.is_floating_point) 8936 self.assertFalse(out.requires_grad) 8937 8938 out = inp.argmin() 8939 self.assertFalse(out.dtype.is_floating_point) 8940 self.assertFalse(out.requires_grad) 8941 8942 out = inp.argsort() 8943 self.assertFalse(out.dtype.is_floating_point) 8944 self.assertFalse(out.requires_grad) 
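# The rule exercised throughout this test: integer-valued outputs are not
# differentiable, so they should not require grad even when the inputs do.
# This covers argmax/argmin/argsort above, searchsorted/bucketize/count_nonzero
# next, and the integer inverse/counts outputs of the unique family at the end
# (whose floating-point values output still requires grad). Sketch:
#   torch.rand(4, requires_grad=True).argmax().requires_grad  # False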
8945 8946 val = torch.rand((), requires_grad=True) 8947 8948 out = torch.searchsorted(inp, val) 8949 self.assertFalse(out.dtype.is_floating_point) 8950 self.assertFalse(out.requires_grad) 8951 8952 bins = torch.linspace(0, 1.0, steps=100, requires_grad=True) 8953 vals = torch.rand(5, 5, requires_grad=True) 8954 out = torch.bucketize(vals, bins) 8955 self.assertFalse(out.dtype.is_floating_point) 8956 self.assertFalse(out.requires_grad) 8957 8958 val = torch.empty(5).requires_grad_() 8959 out = val.count_nonzero() 8960 self.assertFalse(out.requires_grad) 8961 8962 def assert_only_first_requires_grad(res): 8963 if not isinstance(res, tuple): 8964 res = (res,) 8965 self.assertTrue(res[0].requires_grad) 8966 for out in res[1:]: 8967 if out is not None: 8968 self.assertFalse(out.requires_grad) 8969 8970 for sort in [True, False]: 8971 for return_inverse in [True, False]: 8972 for return_counts in [True, False]: 8973 res = torch.unique( 8974 inp, 8975 sorted=sort, 8976 return_inverse=return_inverse, 8977 return_counts=return_counts, 8978 ) 8979 assert_only_first_requires_grad(res) 8980 8981 res = torch.unique( 8982 inp, 8983 sorted=sort, 8984 return_inverse=return_inverse, 8985 return_counts=return_counts, 8986 dim=0, 8987 ) 8988 assert_only_first_requires_grad(res) 8989 8990 res = torch.unique_consecutive( 8991 inp, return_inverse=return_inverse, return_counts=return_counts 8992 ) 8993 assert_only_first_requires_grad(res) 8994 8995 res = torch.unique_consecutive( 8996 inp, 8997 return_inverse=return_inverse, 8998 return_counts=return_counts, 8999 dim=0, 9000 ) 9001 assert_only_first_requires_grad(res) 9002 9003 # Here we test the internal functions to make sure all of them are 9004 # covered on top of the public API 9005 res = torch._unique(inp, sorted=sort, return_inverse=return_inverse) 9006 assert_only_first_requires_grad(res) 9007 9008 # This looks public but is actually manually deleted from the 9009 # torch namespace in torch/functional.py 9010 res = torch._VF.unique_dim( 9011 inp, 9012 dim=0, 9013 sorted=sort, 9014 return_inverse=return_inverse, 9015 return_counts=return_counts, 9016 ) 9017 assert_only_first_requires_grad(res) 9018 9019 # We don't test `unique_dim_consecutive` here. 
9020 # It looks public but the python binding is actually manually disabled in 9021 # tools/autograd/gen_python_functions.py 9022 9023 res = torch._unique2( 9024 inp, 9025 sorted=sort, 9026 return_inverse=return_inverse, 9027 return_counts=return_counts, 9028 ) 9029 assert_only_first_requires_grad(res) 9030 9031 def test_custom_function_cycle(self): 9032 class MyFn(Function): 9033 @staticmethod 9034 def forward(ctx, x, metadata): 9035 x = x.clone() 9036 ctx.meta = metadata 9037 ctx.save_for_backward(x) 9038 return x 9039 9040 @staticmethod 9041 def backward(ctx, gO): 9042 (x,) = ctx.saved_tensors 9043 self.assertEqual(x, 3.14) 9044 self.assertEqual(ctx.meta["foo"], 3.14) 9045 return gO * x, None 9046 9047 def get_refs(with_backward): 9048 a = torch.tensor(3.14, requires_grad=True) 9049 9050 metadata = {} 9051 out = MyFn.apply(a, metadata) 9052 9053 metadata["foo"] = out 9054 9055 if with_backward: 9056 out.sum().backward() 9057 self.assertEqual(a.grad, a) 9058 9059 return torch._C._WeakTensorRef(out) 9060 9061 with disable_gc(): 9062 ref = get_refs(False) 9063 self.assertFalse(ref.expired()) 9064 gc.collect() 9065 self.assertTrue(ref.expired()) 9066 9067 # The backward clears the saved_variables but not the __dict__ 9068 with disable_gc(): 9069 ref = get_refs(True) 9070 self.assertFalse(ref.expired()) 9071 gc.collect() 9072 self.assertTrue(ref.expired()) 9073 9074 def test_create_graph_and_full_backward_hook_cycle(self): 9075 # If BackwardHook saves grad_output, it can create a cycle when we perform backward 9076 # with create_graph=True 9077 # 9078 # grad_output -> grad_output.grad_fn -> graph -> hook -> grad_output 9079 # 9080 class TestCls: 9081 # Dummy class for the purpose of creating a weakref 9082 pass 9083 9084 def get_ref(input_requires_grad, nb_hooks): 9085 t = torch.randn(10, requires_grad=input_requires_grad) 9086 a = torch.tensor(1.0, requires_grad=True) 9087 9088 class Test(nn.Module): 9089 def forward(self, x): 9090 return x**2 * a**2 9091 9092 mod = Test() 9093 9094 for _ in range(nb_hooks): 9095 mod.register_full_backward_hook(lambda a, b, c: None) 9096 9097 tmp = mod(t) 9098 9099 # Save dummy object to graph and get a weak ref to it 9100 test = TestCls() 9101 ref = weakref.ref(test) 9102 tmp.grad_fn.metadata["a"] = test 9103 9104 with set_warn_always_context(True): 9105 with warnings.catch_warnings(record=True) as w: 9106 tmp.exp().sum().backward(create_graph=True) 9107 self.assertTrue(len(w) == 1) 9108 self.assertTrue( 9109 "Using backward() with create_graph=True" in str(w[0].message) 9110 ) 9111 9112 # Remove the backward + create_graph=True cycle 9113 a.grad = None 9114 t.grad = None 9115 9116 return ref 9117 9118 for nb_hooks in (1, 2, 3): 9119 for input_requires_grad in (True, False): 9120 ref_ = get_ref( 9121 input_requires_grad=input_requires_grad, 9122 nb_hooks=nb_hooks, 9123 ) 9124 gc.collect() 9125 self.assertIsNone(ref_()) 9126 9127 @parametrize("use_custom_function", [True, False]) 9128 @parametrize("use_tensor_hook", [True, False]) 9129 def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook): 9130 # This creates a cycle between the hook and grad_fn_b 9131 # hook -> closure -> grad_fn_b (python) -> grad_fn (cpp) -> hook (cpp) 9132 # -> dict -> hook 9133 # 9134 # This test is testing that the grad_fn_b (python) only traverses the 9135 # dict if it is the only one holding a reference to the grad_fn_b (cpp) 9136 # shared_ptr 9137 # 9138 # See: https://github.com/pytorch/pytorch/issues/102174 9139 class Function(torch.autograd.Function): 9140 
@staticmethod 9141 def forward(ctx, x): 9142 return x 9143 9144 @staticmethod 9145 def backward(ctx, grad): 9146 return grad 9147 9148 class Test: 9149 pass 9150 9151 count = [0] 9152 9153 def scope(): 9154 a = torch.tensor(1.0, requires_grad=True) 9155 if use_custom_function: 9156 b = Function.apply(a) 9157 else: 9158 b = a.clone() 9159 grad_fn_b = b.grad_fn 9160 obj = Test() 9161 9162 def hook(*args): 9163 # Make sure this hook's closure holds onto grad_fn_b 9164 # This forms a cycle between the hook and grad_fn_b 9165 # We also hold onto a sentinel object 'obj' to track 9166 # whether this cycle is still alive. See 'ref' below. 9167 grad_fn_b 9168 obj 9169 count[0] += 1 9170 9171 if use_tensor_hook: 9172 b.register_hook(hook) 9173 else: 9174 b.grad_fn.register_hook(hook) 9175 c = b.clone() 9176 ref = weakref.ref(obj) 9177 return c, ref 9178 9179 with disable_gc(): 9180 out, ref = scope() 9181 out.backward(retain_graph=True) 9182 9183 gc.collect() 9184 9185 # Make sure gc does not clear the cycle noted above. 9186 # e.g. the hook is alive and gets fired even after gc runs 9187 out.backward(retain_graph=True) 9188 self.assertEqual(count[0], 2) 9189 9190 # ref is still alive because the use_count of the cpp grad_fn 9191 # shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the 9192 # rest of the graph holds onto the shared_ptr 9193 self.assertIsNotNone(ref()) 9194 9195 # Then delete the rest of the graph and check that ref is dead 9196 del out 9197 gc.collect() 9198 self.assertIsNone(ref()) 9199 9200 def test_full_backward_hook_double_backward(self): 9201 x = torch.rand(1, requires_grad=True) 9202 y = torch.rand_like(x) 9203 9204 func = torch.nn.MSELoss() 9205 counter = [0] 9206 9207 def hook(module, grad_input, grad_output): 9208 counter[0] += 1 9209 9210 func.register_full_backward_hook(hook) 9211 9212 f = func(x, y) 9213 9214 (gradx_f,) = torch.autograd.grad(f, x, create_graph=True) 9215 self.assertEqual(counter[0], 1) 9216 _ = torch.autograd.grad(gradx_f, x) 9217 # We should not error, and counter should not be incremented 9218 self.assertEqual(counter[0], 1) 9219 9220 def test_input_buffer_accum(self): 9221 leaf = torch.rand(2, 2, requires_grad=True) 9222 9223 # An op that returns sparse gradients 9224 ind = torch.tensor([[0, 0]], dtype=torch.long) 9225 out2 = leaf.gather(0, ind, sparse_grad=True) 9226 9227 # An op that returns the gradients as-is 9228 out1 = leaf.clone() 9229 9230 grad_out1_original = torch.rand_like(out1) 9231 grad_out1 = grad_out1_original.clone() 9232 grad_out2 = torch.rand_like(out2) 9233 9234 torch.autograd.backward((out1, out2), (grad_out1, grad_out2)) 9235 9236 # Given gradients should not be modified inplace 9237 self.assertEqual(grad_out1, grad_out1_original) 9238 9239 def test_no_unnecessary_unwrapping(self): 9240 a = torch.randn(5, requires_grad=True) 9241 a_orig = a.detach().clone() 9242 b = a * a 9243 c = a * b 9244 d = torch.exp(a) 9245 9246 # a is leaf 9247 self.assertIs(b.grad_fn._saved_self, a) 9248 self.assertIs(b.grad_fn._saved_other, a) 9249 self.assertIs(c.grad_fn._saved_self, a) 9250 9251 # b is not an output 9252 self.assertIs(c.grad_fn._saved_other, b) 9253 9254 # d is an output 9255 self.assertEqual(d.grad_fn._saved_result, d) 9256 self.assertIsNot(d.grad_fn._saved_result, d) 9257 9258 c.sum().backward() 9259 9260 with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 9261 c.grad_fn._saved_self 9262 9263 # a is left untouched 9264 self.assertEqual(a, a_orig) 9265 9266 def 
test_saved_variable_version_counter(self): 9267 a = torch.rand(2, requires_grad=True) 9268 9269 b = torch.exp(a) 9270 9271 b_unpacked = b.grad_fn._saved_result 9272 self.assertEqual(b, b_unpacked) 9273 self.assertEqual(b._version, b_unpacked._version) 9274 9275 with torch.no_grad(): 9276 b += 1 9277 9278 self.assertEqual(b, b_unpacked) 9279 self.assertEqual(b._version, b_unpacked._version) 9280 9281 def test_saved_variable_packing_unpacking_saved_original_with_hooks(self): 9282 # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks 9283 # The saved_original / did_not_save_original distinction corresponds to the `save_original` 9284 # attribute of `SavedVariable`. 9285 9286 def test(get_input, is_leaf): 9287 a = get_input() 9288 grad_fn = a.grad_fn 9289 y = a * a 9290 y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x / 2) 9291 self.assertEqual(a, y.grad_fn._saved_self) 9292 if not is_leaf: 9293 self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn) 9294 y.sum().backward() 9295 else: 9296 y.sum().backward() 9297 self.assertEqual(2 * a, a.grad) 9298 9299 a = get_input() 9300 grad_fn = a.grad_fn 9301 y = a * a 9302 y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x) 9303 self.assertEqual(2 * a, y.grad_fn._saved_self) 9304 if not is_leaf: 9305 self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn) 9306 y.sum().backward() 9307 else: 9308 y.sum().backward() 9309 self.assertEqual(3 * a, a.grad) 9310 9311 # double backward 9312 a = get_input() 9313 grad_fn = a.grad_fn 9314 y = a**3 9315 y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 9316 s = torch.sum(y) 9317 (g,) = torch.autograd.grad(s, (a,), create_graph=True) 9318 if not is_leaf: 9319 self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn) 9320 g.sum().backward() 9321 else: 9322 g.sum().backward() 9323 self.assertEqual(6 * a, a.grad) 9324 9325 a = get_input() 9326 y = a * a 9327 y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: 1) 9328 with self.assertRaisesRegex( 9329 TypeError, "Output of saved tensor unpack_hook expected to be a Tensor" 9330 ): 9331 print(y.grad_fn._saved_self) 9332 9333 a = get_input() 9334 y = a * a 9335 with self.assertRaisesRegex( 9336 TypeError, "missing 1 required positional argument" 9337 ): 9338 y.grad_fn._raw_saved_self.register_hooks(lambda x, b: x, lambda x: x) 9339 9340 a = get_input() 9341 y = a * a 9342 with self.assertRaisesRegex( 9343 TypeError, "missing 1 required positional argument" 9344 ): 9345 y.grad_fn._raw_saved_self.register_hooks( 9346 lambda x, b: (x, b), lambda x: x 9347 ) 9348 9349 def inplace_double(x): 9350 x *= 2 9351 return x 9352 9353 a = get_input() 9354 t = a * a 9355 9356 with self.assertRaisesRegex( 9357 RuntimeError, 9358 "A saved tensor pack hook is modifying its input in place.", 9359 ): 9360 t.grad_fn._raw_saved_self.register_hooks( 9361 inplace_double, lambda x: x / 2 9362 ) 9363 9364 # leaf 9365 test(lambda: torch.randn(5, requires_grad=True), True) 9366 9367 # not leaf, not output 9368 test(lambda: (1 + torch.randn(5, requires_grad=True)), False) 9369 9370 def test_saved_variable_saved_original_inplace_detach(self): 9371 # Detaching a tensor that is saved input raises 9372 a = torch.tensor(1.0, requires_grad=True).clone() 9373 b = a.sin() 9374 a.detach_() 9375 with self.assertRaisesRegex( 9376 RuntimeError, "Trying to use a saved tensor that has been detached" 9377 ): 9378 b.backward() 9379 9380 # Detaching a tensor that is saved as output is OK 9381 a = torch.tensor(1.0, 
requires_grad=True).clone() 9382 b = a.exp() 9383 a.detach_() 9384 b.backward() 9385 9386 def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self): 9387 # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks 9388 # The saved_original / did_not_save_original distinction corresponds to the `save_original` 9389 # attribute of `SavedVariable`. 9390 9391 a = torch.randn(5, requires_grad=True) 9392 y = torch.exp(a) 9393 y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x) 9394 self.assertEqual(y, y.grad_fn._saved_result) 9395 self.assertIs(y.grad_fn, y.grad_fn._saved_result.grad_fn) 9396 y.sum().backward() 9397 self.assertEqual(a.grad, y) 9398 9399 def test_saved_variable_packing_unpacking_saved_original_with_default_hooks(self): 9400 # Tests that default hooks are properly registered, used and reset 9401 # The saved_original / did_not_save_original distinction corresponds to the `save_original` 9402 # attribute of `SavedVariable`. 9403 # See also: 9404 # - test_saved_variable_packing_unpacking_saved_original_with_hooks 9405 9406 def pack(x): 9407 warnings.warn("pack") 9408 return x 9409 9410 with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x): 9411 a = torch.ones(5, requires_grad=True) 9412 9413 with warnings.catch_warnings(record=True) as w: 9414 warnings.simplefilter("always") 9415 y = a * a 9416 # should raise two warnings from a being saved twice 9417 self.assertEqual(len(w), 2) 9418 9419 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9420 a = torch.randn(5, requires_grad=True) 9421 y = a * a 9422 self.assertEqual(a, y.grad_fn._saved_self) 9423 self.assertEqual(a, y.grad_fn._saved_other) 9424 y.sum().backward() 9425 self.assertEqual(2 * a, a.grad) 9426 9427 with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x / 2): 9428 a = torch.randn(5, requires_grad=True) 9429 y = a * a 9430 self.assertEqual(a, y.grad_fn._saved_self) 9431 self.assertEqual(a, y.grad_fn._saved_other) 9432 y.sum().backward() 9433 self.assertEqual(2 * a, a.grad) 9434 9435 with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x): 9436 a = torch.randn(5, requires_grad=True) 9437 y = a * a 9438 self.assertEqual(2 * a, y.grad_fn._saved_self) 9439 self.assertEqual(2 * a, y.grad_fn._saved_other) 9440 y.sum().backward() 9441 self.assertEqual(4 * a, a.grad) 9442 9443 # Exited hooks correctly 9444 a = torch.randn(5, requires_grad=True) 9445 y = a * a 9446 self.assertEqual(a, y.grad_fn._saved_self) 9447 self.assertEqual(a, y.grad_fn._saved_other) 9448 y.sum().backward() 9449 self.assertEqual(2 * a, a.grad) 9450 9451 def test_saved_variable_packing_unpacking_did_not_save_original_with_default_hooks( 9452 self, 9453 ): 9454 # See also test_saved_variable_packing_unpacking_did_not_save_original_with_hooks 9455 9456 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9457 a = torch.randn(5, requires_grad=True) 9458 y = torch.exp(a) 9459 self.assertEqual(y, y.grad_fn._saved_result) 9460 y.sum().backward() 9461 self.assertEqual(a.grad, y) 9462 9463 def test_setting_default_saved_variable_hooks_twice_should_not_fail(self): 9464 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9465 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9466 pass 9467 9468 def test_setting_default_saved_variable_hooks_twice_should_use_inner(self): 9469 with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x): 9470 b = 
torch.randn(5, requires_grad=True) 9471 with torch.autograd.graph.saved_tensors_hooks( 9472 lambda x: 5 * x, lambda x: 5 * x 9473 ): 9474 a = torch.randn(5, requires_grad=True) 9475 y = a * a 9476 z = b * b 9477 y.sum().backward() 9478 z.sum().backward() 9479 self.assertEqual(2 * 5 * 5 * a, a.grad) 9480 self.assertEqual(2 * 3 * 3 * b, b.grad) 9481 9482 def test_disabling_saved_tensor_hooks(self): 9483 with torch.autograd.graph.disable_saved_tensors_hooks("error message"): 9484 with self.assertRaisesRegex(RuntimeError, "error message"): 9485 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9486 pass 9487 9488 self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled()) 9489 9490 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9491 with self.assertRaisesRegex(RuntimeError, "error message"): 9492 with torch.autograd.graph.disable_saved_tensors_hooks("error message"): 9493 pass 9494 9495 self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled()) 9496 9497 def test_disabling_saved_tensor_hooks_nested(self): 9498 with torch.autograd.graph.disable_saved_tensors_hooks("outer"): 9499 with torch.autograd.graph.disable_saved_tensors_hooks("inner"): 9500 with self.assertRaisesRegex(RuntimeError, "inner"): 9501 with torch.autograd.graph.saved_tensors_hooks( 9502 lambda x: x, lambda x: x 9503 ): 9504 pass 9505 9506 self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled()) 9507 9508 self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled()) 9509 9510 def test_saved_tensor_hooks_custom_error_propagation(self): 9511 class CustomError(Exception): 9512 pass 9513 9514 class error_on_pack_hook(torch.autograd.graph.saved_tensors_hooks): 9515 def __init__(self) -> None: 9516 def pack_hook(x): 9517 raise CustomError("pack") 9518 9519 super().__init__(pack_hook, lambda x: x) 9520 9521 class error_on_unpack_hook(torch.autograd.graph.saved_tensors_hooks): 9522 def __init__(self) -> None: 9523 def unpack_hook(x): 9524 raise CustomError("unpack") 9525 9526 super().__init__(lambda x: x, unpack_hook) 9527 9528 a = torch.tensor(1.0, requires_grad=True) 9529 9530 with error_on_pack_hook(): 9531 with self.assertRaisesRegex(CustomError, "pack"): 9532 out = torch.sin(a) 9533 9534 with error_on_unpack_hook(): 9535 out = torch.sin(a) 9536 with self.assertRaisesRegex(CustomError, "unpack"): 9537 out.backward() 9538 9539 def test_saved_tensor_hooks_custom_function_intermediates(self): 9540 class Func(torch.autograd.Function): 9541 @staticmethod 9542 def forward(ctx, x): 9543 intermediate = x.exp() 9544 ctx.save_for_backward( 9545 intermediate.clone().detach_().requires_grad_(True) 9546 ) 9547 return x.exp() 9548 9549 @staticmethod 9550 def backward(ctx, grad_out): 9551 (intermediate,) = ctx.saved_tensors 9552 return grad_out * intermediate 9553 9554 a = torch.tensor(1.0, requires_grad=True) 9555 9556 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9557 out = Func.apply(a) 9558 out.backward() 9559 9560 def test_unpack_hooks_exec_count(self): 9561 def f(x, y): 9562 return x * y 9563 9564 pack_count = 0 9565 unpack_count = 0 9566 9567 def pack_hook(x): 9568 nonlocal pack_count 9569 pack_count += 1 9570 return x 9571 9572 # unpack hook shouldn't run during compilation, while we trace the forward 9573 def unpack_hook(x): 9574 nonlocal unpack_count 9575 unpack_count += 1 9576 return x 9577 9578 x = torch.ones(4, requires_grad=True) 9579 y = torch.ones(4, requires_grad=False) 9580 with 
torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): 9581 out_test = f(x, y) 9582 self.assertEqual(pack_count, 1) 9583 self.assertEqual(unpack_count, 0) 9584 out_test.sum().backward() 9585 self.assertEqual(pack_count, 1) 9586 self.assertEqual(unpack_count, 1) 9587 9588 def test_saved_tensors_hook_version_counter_not_shared(self): 9589 class Test(torch.autograd.Function): 9590 @staticmethod 9591 def forward(ctx, x): 9592 ctx.save_for_backward(x) 9593 return x.sin() 9594 9595 @staticmethod 9596 def backward(ctx, grad_output): 9597 (x,) = ctx.saved_tensors 9598 before = a._version 9599 x.add_(1) 9600 self.assertEqual(a._version, before) 9601 return grad_output 9602 9603 a = torch.tensor(1.0, requires_grad=True) 9604 a_replacement = a.clone() 9605 9606 def pack_hook(x): 9607 return a_replacement 9608 9609 def unpack_hook(x): 9610 return x 9611 9612 with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook): 9613 b = Test.apply(a) 9614 9615 b.backward() 9616 9617 def test_save_on_cpu_and_checkpoint(self): 9618 a = torch.randn(2, 2, requires_grad=True) 9619 9620 b = a.pow(2).pow(2).pow(2).pow(2) 9621 b.sum().backward() 9622 b_grad = a.grad.clone() 9623 a.grad.zero_() 9624 9625 with torch.autograd.graph.save_on_cpu(): 9626 h = a.pow(2) 9627 h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False) 9628 c = h.pow(2) 9629 c.sum().backward() 9630 c_grad = a.grad.clone() 9631 a.grad.zero_() 9632 9633 def f(a): 9634 h = a.pow(2) 9635 with torch.autograd.graph.save_on_cpu(): 9636 h = h.pow(2).pow(2) 9637 return h.pow(2) 9638 9639 d = checkpoint(f, a, use_reentrant=False) 9640 d.sum().backward() 9641 d_grad = a.grad.clone() 9642 9643 self.assertEqual(b_grad, c_grad) 9644 self.assertEqual(b_grad, d_grad) 9645 9646 def test_pack_hook_with_inplace_modification_should_fail(self): 9647 a = torch.randn(5, requires_grad=True) 9648 9649 def inc(x): 9650 x += 1 9651 return x 9652 9653 with torch.autograd.graph.saved_tensors_hooks(inc, lambda x: x): 9654 with self.assertRaisesRegex( 9655 RuntimeError, 9656 "A saved tensor pack hook is modifying its input in place.", 9657 ): 9658 y = torch.exp(a) 9659 9660 y = torch.exp(a) 9661 with self.assertRaisesRegex( 9662 RuntimeError, "A saved tensor pack hook is modifying its input in place." 
9663 ): 9664 y.grad_fn._raw_saved_result.register_hooks(inc, lambda x: x) 9665 9666 def test_saving_variable_to_disk(self): 9667 with tempfile.TemporaryDirectory() as tmp_dir: 9668 9669 def pack(x): 9670 name = os.path.join(tmp_dir, str(uuid.uuid4())) 9671 torch.save(x, name) 9672 return name 9673 9674 def unpack(name): 9675 return torch.load(name) 9676 9677 with torch.autograd.graph.saved_tensors_hooks(pack, unpack): 9678 a = torch.ones(5, requires_grad=True) 9679 y = a * a 9680 self.assertEqual(a, y.grad_fn._saved_self) 9681 9682 y.sum().backward() 9683 self.assertEqual(2 * a, a.grad) 9684 9685 def test_default_saved_tensors_hooks_double_backward(self): 9686 with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x): 9687 a = torch.randn(5, requires_grad=True) 9688 y = a**3 9689 s = torch.sum(y) 9690 (g,) = torch.autograd.grad(s, (a,), create_graph=True) 9691 g.sum().backward() 9692 self.assertEqual(6 * a, a.grad) 9693 9694 with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x): 9695 a = torch.randn(5, requires_grad=True) 9696 y = a**3 9697 s = torch.sum(y) 9698 (g,) = torch.autograd.grad(s, (a,), create_graph=True) 9699 g.sum().backward() 9700 # factor 2 because only a is saved once 9701 self.assertEqual(6 * 2 * a, a.grad) 9702 9703 a = torch.randn(5, requires_grad=True) 9704 y = a**3 9705 s = torch.sum(y) 9706 with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x): 9707 (g,) = torch.autograd.grad(s, (a,), create_graph=True) 9708 g.sum().backward() 9709 # factor 4 because pow_backward is grad * (exp * self.pow(exp - 1)) 9710 # so grad is saved and self (i.e. a) is saved 9711 self.assertEqual(6 * 4 * a, a.grad) 9712 9713 with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x): 9714 a = torch.randn(5, requires_grad=True) 9715 y = a**3 9716 s = torch.sum(y) 9717 (g,) = torch.autograd.grad(s, (a,), create_graph=True) 9718 g.sum().backward() 9719 # combining the two above blocks: 2 * 4 = 8 9720 # note that in that sense, a is saved twice 9721 self.assertEqual(6 * 8 * a, a.grad) 9722 9723 def test_wrapped_number_saved_tensors_hooks(self): 9724 def err_hook(x): 9725 raise RuntimeError("this hook should not be called") 9726 9727 with torch.autograd.graph.saved_tensors_hooks(err_hook, err_hook): 9728 a = torch.randn(5, requires_grad=True) 9729 out = (a * 3).sum() 9730 # 3 is saved as a saved tensor because it is a wrapped number, but 9731 # wrapped numbers should be special cased to not trigger saved variable hooks 9732 torch.autograd.grad(out, (a,)) 9733 9734 def test_graph_save_on_cpu(self): 9735 def test(get_input, cuda, pin_memory): 9736 with torch.autograd.graph.save_on_cpu(pin_memory): 9737 a = get_input() 9738 if cuda: 9739 a.cuda() 9740 y = a * a 9741 self.assertEqual(a, y.grad_fn._saved_self) 9742 self.assertEqual(a, y.grad_fn._saved_other) 9743 self.assertEqual(a.dtype, y.grad_fn._saved_self.dtype) 9744 self.assertEqual(a.layout, y.grad_fn._saved_self.layout) 9745 if y.is_sparse: 9746 y = y.to_dense() 9747 y.sum().backward() 9748 9749 actual = 2 * a 9750 expected = a.grad 9751 if a.is_sparse: 9752 actual = actual.coalesce() 9753 expected = expected.coalesce() 9754 9755 self.assertEqual(actual, expected) 9756 9757 for cuda in [False] + ([True] if torch.cuda.is_available() else []): 9758 for pin_memory in [True, False]: 9759 # FloatTensor 9760 test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory) 9761 # DoubleTensor 9762 test( 9763 lambda: torch.randn(5, requires_grad=True, dtype=torch.double), 9764 
cuda, 9765 pin_memory, 9766 ) 9767 # Sparse tensor 9768 x = torch.sparse_coo_tensor( 9769 torch.tensor([[1, 1]]).long(), 9770 torch.tensor([1.0, 1.0]), 9771 requires_grad=True, 9772 ) 9773 test(lambda: x, cuda, pin_memory) 9774 9775 @unittest.skipIf(not TEST_CUDA, "test requires CUDA") 9776 def test_graph_save_on_cpu_cuda(self): 9777 def f(x): 9778 a = x + 1 9779 return a * a 9780 9781 # with grad 9782 a = torch.ones(1, requires_grad=True, device="cuda") 9783 y = f(a) 9784 memory_with_grad = torch.cuda.memory_allocated() 9785 9786 del a 9787 del y 9788 9789 # without grad 9790 a = torch.ones(1, requires_grad=True, device="cuda") 9791 with torch.no_grad(): 9792 y = f(a) 9793 memory_without_grad = torch.cuda.memory_allocated() 9794 9795 self.assertGreater(memory_with_grad, memory_without_grad) 9796 9797 del a 9798 del y 9799 9800 # with hooks 9801 with torch.autograd.graph.save_on_cpu(): 9802 a = torch.ones(1, requires_grad=True, device="cuda") 9803 y = f(a) 9804 memory_with_hooks = torch.cuda.memory_allocated() 9805 self.assertEqual(memory_with_hooks, memory_without_grad) 9806 9807 @unittest.skipIf(not TEST_CUDA, "test requires CUDA") 9808 def test_scalar_grad_mixed_device(self): 9809 x = torch.tensor(1.0, requires_grad=True) 9810 y = torch.randn(2, 2, device="cuda") 9811 out = x * y 9812 out.sum().backward() 9813 9814 def test_multi_grad_all_hooks(self): 9815 t1 = torch.rand(2, requires_grad=True) 9816 t2 = torch.rand(2, requires_grad=True) 9817 t3 = torch.rand(2, requires_grad=True) 9818 t4 = torch.rand(2, requires_grad=True) 9819 9820 # Ensure we properly detect all types of Nodes here 9821 # C++ Node 9822 t1 = t1.mul(2) 9823 9824 # Python custom Function 9825 class Foo(Function): 9826 @staticmethod 9827 def forward(ctx, a): 9828 return a.clone() 9829 9830 @staticmethod 9831 def backward(ctx, gO): 9832 return gO 9833 9834 t2 = Foo.apply(t2) 9835 9836 # C++ Node 9837 t3 = torch._C._functions.UndefinedGrad()(t3) 9838 9839 # C++ Custom Op 9840 cpp_source = """ 9841struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> { 9842 static torch::Tensor forward( 9843 torch::autograd::AutogradContext* ctx, 9844 const torch::Tensor& x) { 9845 return x.clone(); 9846 } 9847 9848 static torch::autograd::variable_list backward( 9849 torch::autograd::AutogradContext *ctx, 9850 torch::autograd::variable_list grad_output) { 9851 return grad_output; 9852 } 9853}; 9854 9855torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) { 9856 return CustomOpAutogradFunction::apply(x); 9857} 9858 9859TORCH_LIBRARY(test_autograd_cpp_node, m) { 9860 m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); 9861} 9862 """ 9863 9864 module = load_inline( 9865 name="test_autograd_cpp_node", 9866 cpp_sources=cpp_source, 9867 functions="custom_op_backed_by_autograd_fn", 9868 verbose=True, 9869 ) 9870 9871 t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4) 9872 9873 res = [None] * 4 9874 count = [0] 9875 9876 def hook(grads): 9877 nonlocal res 9878 count[0] += 1 9879 res = [g is not None for g in grads] 9880 9881 handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook) 9882 9883 out = t2 * t3 9884 9885 out.sum().backward(inputs=(t2, t3), retain_graph=True) 9886 self.assertEqual(count[0], 1) 9887 self.assertEqual(res, [False, True, True, False]) 9888 9889 out.sum().backward(inputs=(t1, t4), retain_graph=True) 9890 self.assertEqual(count[0], 1) 9891 9892 out.sum().backward(inputs=(t1, t3), retain_graph=True) 9893 
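        # With the default mode="all", the multi-grad hook only fires for a backward
        # call in which at least one of the registered tensors actually receives a
        # gradient; registered tensors outside that call's graph are reported as None.
        # That is why the previous call with inputs=(t1, t4) did not bump the counter,
        # while the call above (which produces a gradient for t3) does.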
        self.assertEqual(count[0], 2)
        self.assertEqual(res, [False, False, True, False])

        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                raise RuntimeError("error message")

        out = Func.apply(t2) * t3
        with self.assertRaisesRegex(RuntimeError, "error message"):
            out.sum().backward(inputs=(t2, t3), retain_graph=True)
        self.assertEqual(count[0], 2)

        handle.remove()
        out.sum().backward(inputs=(t1, t3), retain_graph=True)
        self.assertEqual(count[0], 2)

    def test_multi_grad_any_hooks(self):
        hook_id = 0
        any_hook_handles: List[RemovableHandle] = []

        class MultiOutputModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(3, 3)

            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                z = self.lin(x)
                out = torch.sin(z), torch.cos(z)
                nonlocal hook_id
                z.register_hook(partial(hook, hook_id))
                hook_id += 1
                any_hook_handles.append(
                    torch.autograd.graph.register_multi_grad_hook(
                        out, partial(hook, hook_id), mode="any"
                    )
                )
                hook_id += 1
                return out

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = MultiOutputModule()
                self.mod2 = MultiOutputModule()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = self.mod1(x)
                z = y[0] + y[1]
                return self.mod2(z)

        hook_order: List[int] = []
        hook_count = 0

        def hook(hook_id: int, *unused):
            nonlocal hook_count
            nonlocal hook_order
            hook_count += 1
            hook_order.append(hook_id)

        # Any hooks: IDs 1 and 3; regular hooks: IDs 0 and 2
        model = Model()
        inp = torch.randn((2, 3))
        out = model(inp)
        (out[0] + out[1]).sum().backward()
        # Check that the any-hook runs only once and before the regular hook
        # for each module
        self.assertEqual(len(any_hook_handles), 2)
        self.assertEqual(hook_order, [3, 2, 1, 0])

        hook_id = 0
        hook_order.clear()
        any_hook_handles.clear()
        out = model(inp)
        for handle in any_hook_handles:
            handle.remove()
        (out[0] + out[1]).sum().backward()
        # Check that the any-hook does not run if removed
        self.assertEqual(hook_order, [2, 0])

    def test_multi_grad_hooks_invalid_mode(self):
        t1 = torch.rand(2, requires_grad=True)
        t2 = torch.rand(2, requires_grad=True)
        regex = r"Expects mode to be one of \('all', 'any'\) but got foo"
        with self.assertRaisesRegex(ValueError, regex):
            torch.autograd.graph.register_multi_grad_hook(
                (t1, t2), lambda _: None, mode="foo"
            )

    def test_pynode_destruction_deadlock(self):
        script = """
import torch

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, gO):
        return gO.clone()

def get_out():
    inp = torch.rand(2, requires_grad=True)

    # The python function is first so that it runs
    # last in the backward pass
    right = Foo.apply(inp)

    # An op that creates new memory
    left1 = inp.clone()
    # An op that saves its input
    left2 = left1 ** 2

    # Inplace modify so that the backward for
    # left2 always raises an error
    left1 += 1

    # An op that takes both sides as input.
    # After running, both sides' last op will be in
    # the ready queue
    # And the op for left will run first as it was
    # executed last during the forward
    out = left2 + right

    return out

# Nothing should be a global variable here as, from what
# I can see, python leaks all the global objects
get_out().sum().backward()

# This used to deadlock when the PyNode is being destroyed after
# the error is raised.
"""
        try:
            subprocess.check_output(
                [sys.executable, "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
                # It is ok to have an extra long timeout here as a timeout means the test failed
                timeout=20,
            )
        except subprocess.TimeoutExpired as e:
            self.fail(
                msg="Example code timed out! See the code sample in the test for details."
            )
        except subprocess.CalledProcessError as e:
            if e.returncode < 0:
                # Sometimes we segfault instead of deadlocking
                self.fail("Subprocess exited with a fatal signal")
            else:
                err_msg = (
                    "RuntimeError: one of the variables needed for gradient computation"
                )
                self.assertTrue(err_msg in e.output.decode("utf-8"))

    def test_view_func_replay(self):
        with torch.autograd._force_original_view_tracking(True):

            def _assert_match_metadata(a, b):
                self.assertEqual(a.size(), b.size())
                self.assertEqual(a.stride(), b.stride())
                self.assertEqual(a.storage_offset(), b.storage_offset())
                self.assertEqual(a.device, b.device)
                self.assertEqual(a.dtype, b.dtype)

            def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
                outs = fn(inp, *args)
                # handle functions that return multiple views (e.g.
split) 10068 if isinstance(outs, torch.Tensor): 10069 outs = [outs] 10070 10071 for out in outs: 10072 self.assertTrue(out._is_view()) 10073 self.assertTrue(out._base is inp) 10074 10075 # forward view_func 10076 new_inp = inp.clone() 10077 _assert_match_metadata(new_inp, inp) 10078 if use_unsafe_view_func: 10079 new_out = out._view_func_unsafe(new_inp) 10080 else: 10081 new_out = out._view_func(new_inp) 10082 _assert_match_metadata(new_out, out) 10083 self.assertEqual(new_out, out) 10084 10085 # reverse view_func 10086 new_out = out.detach() 10087 new_inp = out._rev_view_func_unsafe(new_out) 10088 _assert_match_metadata(new_inp, inp) 10089 self.assertTrue(new_inp._is_view()) 10090 self.assertTrue(new_inp._base is new_out) 10091 10092 # test individual view ops 10093 _test_fn(torch.ops.aten.alias.default, torch.rand(2, 2)) 10094 _test_fn(torch.as_strided, torch.rand(2, 2), (4,), (1,)) 10095 _test_fn(torch.chunk, torch.rand(2, 4), 2, -1) 10096 _test_fn(torch.diagonal, torch.rand(4, 4)) 10097 _test_fn(torch.ops.aten.expand.default, torch.rand(4, 1), (-1, 3)) 10098 _test_fn(torch.narrow, torch.rand(2, 2), 0, 1, 1) 10099 _test_fn(torch.permute, torch.rand(2, 3, 4), (1, 0, 2)) 10100 _test_fn(torch.select, torch.rand(2, 2), 0, 0) 10101 _test_fn(torch.ops.aten.slice.Tensor, torch.rand(2, 2), 1, 1, 2) 10102 _test_fn(torch.split, torch.rand(2, 2), 1) 10103 _test_fn(torch.split_with_sizes, torch.rand(2, 4), [1, 3], -1) 10104 _test_fn(torch.squeeze, torch.rand(2, 1, 4)) 10105 _test_fn(torch.squeeze, torch.rand(2, 1, 4), 1) 10106 _test_fn(torch.squeeze, torch.rand(2, 1, 1, 4), [1, 2]) 10107 _test_fn(torch.t, torch.rand(2, 4)) 10108 _test_fn(torch.transpose, torch.rand(2, 4), 0, 1) 10109 _test_fn(torch.unbind, torch.rand(1, 5)) 10110 _test_fn(torch.ops.aten.unfold.default, torch.rand(1, 5), 1, 3, 2) 10111 _test_fn(torch.unsqueeze, torch.rand(2, 4), -2) 10112 _test_fn(torch.ops.aten.view.default, torch.rand(2, 10), (-1, 5, 2)) 10113 _test_fn(torch.view_as_complex, torch.rand(2, 2)) 10114 _test_fn(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat)) 10115 10116 # test view chains 10117 _test_fn( 10118 lambda x: x.unsqueeze(-1).transpose(-1, -2).squeeze(1), 10119 torch.randn(2, 4), 10120 ) 10121 _test_fn( 10122 lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1), 10123 torch.randn(2, 3, 4), 10124 ) 10125 _test_fn( 10126 lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0), 10127 torch.randn(2, 3, 4), 10128 ) 10129 10130 # chains with missing view_func()s use as_strided() to cover the gaps 10131 def chain_with_only_parent_view_func(x): 10132 with torch.autograd._force_original_view_tracking(True): 10133 x = x.split_with_sizes([1, 3], -1)[0] 10134 10135 with torch.autograd._force_original_view_tracking(False): 10136 x = x.chunk(2, 0) 10137 10138 return x 10139 10140 _test_fn(chain_with_only_parent_view_func, torch.randn(2, 3, 4)) 10141 10142 def chain_with_only_current_view_func(x): 10143 with torch.autograd._force_original_view_tracking(False): 10144 x = x.split_with_sizes([1, 3], -1)[0] 10145 10146 with torch.autograd._force_original_view_tracking(True): 10147 x = x.chunk(2, 0) 10148 10149 return x 10150 10151 _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4)) 10152 10153 # TODO: Move this somewhere else 10154 # test NT views 10155 from torch.nested._internal.nested_tensor import ( 10156 nested_view_from_values_offsets, 10157 ) 10158 10159 values = torch.randn(10, 5) 10160 offsets = torch.tensor([0, 3, 6, 10]) 10161 _test_fn(nested_view_from_values_offsets, values, offsets) 
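            # Illustrative sketch (comments only, not extra test coverage): what
            # `_test_fn` checks above is that a view can be "replayed" onto a fresh
            # base via the private `_view_func` API. Roughly, assuming a simple
            # transpose view:
            #
            #     base = torch.randn(2, 3)
            #     v = base.t()                       # a view of `base`
            #     new_base = base.clone()
            #     replayed = v._view_func(new_base)  # same view, replayed on `new_base`
            #     assert replayed.equal(new_base.t())
            #
            # `_rev_view_func_unsafe` goes the other way: given a tensor shaped like
            # the view, it reconstructs a view of it whose metadata matches the
            # original base.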
10162 10163 nt = nested_view_from_values_offsets(values, offsets).clone().detach() 10164 _test_fn( 10165 torch.ops.aten._nested_get_values.default, nt, use_unsafe_view_func=True 10166 ) 10167 10168 def chain_nt_to_dense_back_and_forth(nt): 10169 # NJT1 -> dense -> NJT2 -> dense 10170 offsets2 = nt.offsets().clone().detach() 10171 return nested_view_from_values_offsets(nt.values(), offsets2).values() 10172 10173 _test_fn(chain_nt_to_dense_back_and_forth, nt, use_unsafe_view_func=True) 10174 10175 def chain_dense_to_nt_back_and_forth(values, offsets): 10176 offsets2 = offsets.clone().detach() 10177 # dense -> NJT1 -> dense -> NJT2 10178 return nested_view_from_values_offsets( 10179 nested_view_from_values_offsets(values, offsets).values(), offsets2 10180 ) 10181 10182 _test_fn( 10183 chain_dense_to_nt_back_and_forth, 10184 values, 10185 offsets, 10186 use_unsafe_view_func=True, 10187 ) 10188 10189 def test_view_func_replay_with_modified_state(self): 10190 with torch.autograd._force_original_view_tracking(True): 10191 base = torch.randn(3, 4, 5) 10192 view = base.select(1, 2) 10193 10194 def symint_visitor_fn(x): 10195 # modify saved index 10196 return x + 1 10197 10198 # ensure modifying state changes view replay 10199 new_base = torch.randn_like(base) 10200 new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn) 10201 self.assertEqual(new_view, new_base.select(1, 3)) 10202 10203 # ensure saved state reverts back afterwards 10204 self.assertEqual(view._view_func(new_base), new_base.select(1, 2)) 10205 10206 # check modifying tensor state. currently, slice_inverse() is the only 10207 # view that saves a tensor 10208 base = torch.randn(3, 4, 5) 10209 sliced = base[:, 2:3, :].detach() 10210 view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1) 10211 10212 replacement_shape = (1, 2, 3) 10213 10214 def tensor_visitor_fn(x): 10215 # return tensor with a smaller shape than the saved one 10216 return torch.randn(*replacement_shape) 10217 10218 # ensure modifying state changes view replay 10219 new_sliced = torch.ones_like(base)[:, 2:3, :].detach() 10220 new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn) 10221 self.assertEqual(new_view.shape, replacement_shape) 10222 self.assertEqual( 10223 new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1)) 10224 ) 10225 10226 # ensure saved state reverts back afterwards 10227 self.assertEqual(view._view_func(sliced), base) 10228 10229 def test_setup_context_when_forward_has_default_args(self): 10230 class PowFunction(Function): 10231 @staticmethod 10232 def forward(x, y=3): 10233 return torch.pow(x, y) 10234 10235 @staticmethod 10236 def setup_context(ctx, inputs, output): 10237 x, y = inputs 10238 ctx.save_for_backward(x) 10239 ctx.y = y 10240 10241 @staticmethod 10242 def backward(ctx, gO): 10243 (x,) = ctx.saved_tensors 10244 y = ctx.y 10245 return gO * y * torch.pow(x, y - 1), None 10246 10247 class PowFunctionWithClassmethod(Function): 10248 @classmethod 10249 def forward(cls, x, y=3): 10250 return torch.pow(x, y) 10251 10252 @classmethod 10253 def setup_context(cls, ctx, inputs, output): 10254 x, y = inputs 10255 ctx.save_for_backward(x) 10256 ctx.y = y 10257 10258 @classmethod 10259 def backward(cls, ctx, gO): 10260 (x,) = ctx.saved_tensors 10261 y = ctx.y 10262 return gO * y * torch.pow(x, y - 1), None 10263 10264 x = torch.tensor(2.0, requires_grad=True) 10265 10266 y = torch.tensor(8.0) 10267 y_expected = torch.tensor(12.0) 10268 10269 y1 = PowFunction.apply(x) 10270 (y1_expected,) = 
torch.autograd.grad(y1, x) 10271 10272 y2 = PowFunctionWithClassmethod.apply(x) 10273 (y2_expected,) = torch.autograd.grad(y2, x) 10274 10275 self.assertEqual(y, y1) 10276 self.assertEqual(y_expected, y1_expected) 10277 self.assertEqual(y, y2) 10278 self.assertEqual(y_expected, y2_expected) 10279 10280 @unittest.skipIf(not TEST_CUDA, "test requires CUDA") 10281 def test_gradcheck_default_device_placement_context(self): 10282 # During gradcheck with fast_mode=True, we create a random vector on the CPU device using a CPU generator. 10283 # This test ensures that this still works when the default device is set to something else by the user. 10284 with torch.device("cuda"): 10285 x = torch.randn(3, dtype=torch.double, requires_grad=True) 10286 10287 def func(inp): 10288 return inp**2.0 10289 10290 self.assertTrue(gradcheck(func, x, fast_mode=True)) 10291 10292 10293def index_perm_variable(shape, max_indices): 10294 if not isinstance(shape, tuple): 10295 shape = (shape,) 10296 10297 index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape) 10298 return index 10299 10300 10301def bernoulli_scalar(): 10302 return torch.tensor(0, dtype=torch.uint8).bernoulli_() 10303 10304 10305class TestAutogradForwardModeBatchedGrad(TestCase): 10306 def test_out_of_place_basic(self): 10307 a = torch.rand(4, 4, dtype=torch.double, requires_grad=True) 10308 b = torch.rand(4, 4, dtype=torch.double, requires_grad=True) 10309 self.assertTrue( 10310 gradcheck( 10311 torch.sin, 10312 a, 10313 check_forward_ad=True, 10314 check_batched_grad=True, 10315 check_batched_forward_grad=True, 10316 ) 10317 ) 10318 self.assertTrue( 10319 gradcheck( 10320 torch.add, 10321 (a, b), 10322 check_forward_ad=True, 10323 check_batched_grad=True, 10324 check_batched_forward_grad=True, 10325 ) 10326 ) 10327 10328 def test_out_of_place_not_same_layout(self): 10329 input = torch.zeros([2, 2]).transpose(0, 1) 10330 tangent = torch.zeros([2, 2, 2]) 10331 10332 def jvp(tangent): 10333 with fwAD.dual_level(): 10334 x = fwAD.make_dual(input, tangent) 10335 return fwAD.unpack_dual(x)[1] 10336 10337 x_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent) 10338 10339 self.assertIsNot(x_tangent, tangent) 10340 10341 def test_inplace_on_view_same_layout(self): 10342 input = torch.zeros([2, 2]) 10343 tangent = torch.zeros([2, 2, 2]) 10344 base = torch.zeros([2, 2]) 10345 view = base.view_as(base) 10346 10347 def jvp(tangent): 10348 with fwAD.dual_level(): 10349 x = fwAD.make_dual(input, tangent) 10350 view.copy_(x) 10351 return ( 10352 fwAD.unpack_dual(x)[1], 10353 fwAD.unpack_dual(view)[1], 10354 fwAD.unpack_dual(view._base)[1], 10355 ) 10356 10357 x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)( 10358 tangent 10359 ) 10360 10361 self.assertFalse( 10362 view_tangent._is_view() 10363 ) # Optimization to share the same tensor! 
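        # Because `view` has exactly the same layout as its base, forward AD can
        # reuse a single tangent tensor for both of them (checked below with
        # assertIs) instead of materializing a separate view of the base's tangent.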
10364 self.assertIs(view_tangent, base_tangent) 10365 self.assertIs(x_tangent, tangent) 10366 10367 def test_inplace_on_view_not_same_layout(self): 10368 input = torch.zeros([2, 2]) 10369 tangent = torch.zeros([2, 2, 2]) 10370 view = torch.zeros([2, 2]).transpose(0, 1) 10371 10372 def jvp(tangent): 10373 with fwAD.dual_level(): 10374 x = fwAD.make_dual(input, tangent) 10375 view.copy_(x) 10376 return ( 10377 fwAD.unpack_dual(x)[1], 10378 fwAD.unpack_dual(view)[1], 10379 fwAD.unpack_dual(view._base)[1], 10380 ) 10381 10382 x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)( 10383 tangent 10384 ) 10385 10386 self.assertIs(view_tangent._base, base_tangent) 10387 self.assertIs(x_tangent, tangent) 10388 self.assertIsNot(view_tangent, tangent) 10389 10390 def test_metadata_check_for_storage_numel_skipped(self): 10391 # See: test_metadata_check_checks_storage_numel for the reverse of this test 10392 primal = torch.randn(5)[:4].detach() 10393 self.assertEqual(len(primal.storage()), 5) 10394 tangent = torch.randn(10, 4) 10395 10396 def jvp(tangent): 10397 with fwAD.dual_level(): 10398 dual = fwAD.make_dual(primal, tangent) 10399 _, unpacked_tangent = fwAD.unpack_dual(dual) 10400 10401 # No copy is made 10402 self.assertIs(tangent, unpacked_tangent) 10403 10404 # as_strided raises 10405 with self.assertRaisesRegex( 10406 RuntimeError, "can access memory outside of `tensor`" 10407 ): 10408 dual.as_strided((5,), (1,), 0) 10409 return unpacked_tangent 10410 10411 torch._vmap_internals._vmap(jvp, 0, 0)(tangent) 10412 10413 10414class TestAutogradForwardMode(TestCase): 10415 def tearDown(self): 10416 # Ensure that a failing test won't make others fail 10417 while fwAD._current_level >= 0: 10418 fwAD.exit_dual_level() 10419 10420 super().tearDown() 10421 10422 def test_forward_level_cleanup(self): 10423 def get_tensor_and_weak_ref(): 10424 # Create a new Tensor and weak reference 10425 t = torch.rand(2, requires_grad=True) 10426 return t, torch._C._WeakTensorRef(t) 10427 10428 # Sanity check that the helper function works as expected 10429 t, t_ref = get_tensor_and_weak_ref() 10430 self.assertFalse(t_ref.expired()) 10431 10432 del t 10433 self.assertTrue(t_ref.expired()) 10434 10435 # Main test code 10436 foo = torch.rand(2) 10437 10438 with fwAD.dual_level(): 10439 tangent, tangent_ref = get_tensor_and_weak_ref() 10440 self.assertFalse(tangent_ref.expired()) 10441 10442 dual = fwAD.make_dual(foo, tangent) 10443 self.assertFalse(tangent_ref.expired()) 10444 10445 # Make sure that the tangent we provided has been re-used as is 10446 self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent) 10447 10448 # Make sure that dual is keeping the tangent alive 10449 del tangent 10450 self.assertFalse(tangent_ref.expired()) 10451 10452 # Make sure that the dual level does not keep the c++ 10453 # version of the tangent alive 10454 del dual 10455 self.assertTrue(tangent_ref.expired()) 10456 10457 def test_size_check(self): 10458 foo = torch.rand(2) 10459 tangent = torch.rand(3) 10460 10461 with fwAD.dual_level(): 10462 with self.assertRaisesRegex( 10463 RuntimeError, 10464 "Trying to set a forward gradient that has a different size", 10465 ): 10466 dual = fwAD.make_dual(foo, tangent) 10467 10468 dual = fwAD.make_dual(foo, tangent[1:]) 10469 10470 def test_metadata_check_checks_storage_numel(self): 10471 primal = torch.randn(5)[:4].detach() 10472 self.assertEqual(len(primal.storage()), 5) 10473 tangent = torch.randn(4) 10474 10475 with fwAD.dual_level(): 10476 dual = fwAD.make_dual(primal, 
tangent) 10477 _, unpacked_tangent = fwAD.unpack_dual(dual) 10478 10479 # # Verify that mutating unpacked tangent does not affect the original tangent 10480 tangent_clone = tangent.clone() 10481 unpacked_tangent *= 2 10482 self.assertTrue(torch.allclose(tangent_clone, tangent)) 10483 10484 # as_strided runs without error 10485 dual.as_strided((5,), (1,), 0) 10486 10487 def test_metadata_check_checks_ignores_size_zero(self): 10488 a = torch.ones(0).as_strided((0, 1), (1, 1), 0) 10489 b = torch.ones(0).as_strided((0, 1), (1, 0), 0) 10490 10491 with fwAD.dual_level(): 10492 dual = fwAD.make_dual(a, b) 10493 torch.diagonal(dual, offset=0) 10494 10495 input = torch.rand([0, 1], dtype=torch.complex128, requires_grad=True) 10496 func = partial(torch.diagonal, offset=0) 10497 torch.autograd.gradcheck(func, (input,), check_forward_ad=True) 10498 10499 def test_metadata_check_when_primal_has_conj_bit(self): 10500 # Make sure the _has_same_storage_numel is a fallthrough, so that 10501 # conj bit does not materialize. If it materializes it would 10502 # cause the layout check to fail for views that do not index the 10503 # the entire storage. 10504 a = torch.randn(2, 2, dtype=torch.cdouble).conj() 10505 b = torch.rand_like(a) 10506 10507 self.assertTrue(torch.is_conj(a)) 10508 self.assertEqual(len(a.storage()), len(b.storage())) 10509 10510 with fwAD.dual_level(): 10511 dual = fwAD.make_dual(a, b) 10512 dual[1:] 10513 10514 def test_metadata_check_when_primal_has_neg_bit(self): 10515 # Make sure the _has_same_storage_numel is a fallthrough, so that 10516 # conj bit does not materialize. If it materializes it would 10517 # cause the layout check to fail for views that do not index the 10518 # the entire storage. 10519 a = torch.randn(2, 2, dtype=torch.cdouble).conj().imag 10520 b = torch.randn(2, 2, dtype=torch.cdouble).imag 10521 10522 self.assertTrue(torch.is_neg(a)) 10523 self.assertEqual(len(a.storage()), len(b.storage())) 10524 10525 with fwAD.dual_level(): 10526 dual = fwAD.make_dual(a, b) 10527 dual[1:] 10528 10529 def test_metadata_check_check_conj(self): 10530 keys = { 10531 "NEITHER": lambda x: x, 10532 "CONJ": lambda x: x.conj(), 10533 "NEG": lambda x: x._neg_view(), 10534 } 10535 10536 for primal_key, tangent_key in product(keys, keys): 10537 x = keys[primal_key](torch.randn(2, 3, 4, dtype=torch.cdouble)) 10538 t = keys[tangent_key](torch.randn(2, 3, 4, dtype=torch.cdouble)) 10539 10540 if primal_key == tangent_key: 10541 with fwAD.dual_level(): 10542 dual = fwAD.make_dual(x, t) 10543 self.assertTrue(fwAD.unpack_dual(dual).tangent is t) 10544 torch.real(dual) 10545 torch.imag(dual) 10546 else: 10547 with fwAD.dual_level(): 10548 dual = fwAD.make_dual(x, t) 10549 self.assertTrue(fwAD.unpack_dual(dual).tangent is not t) 10550 torch.real(dual) 10551 torch.imag(dual) 10552 10553 def test_metadata_check_ignore_storage_offset_for_zero_numel_tensor(self): 10554 # See https://github.com/pytorch/pytorch/issues/80507 10555 a = torch.tensor([1.0]).as_strided((0,), (1,), 1) 10556 b = torch.tensor([1.0]).as_strided((0,), (1,), 2) 10557 10558 with fwAD.dual_level(): 10559 dual_input = fwAD.make_dual(a, b) 10560 # Check that no copy is made 10561 self.assertIs(fwAD.unpack_dual(dual_input).tangent, b) 10562 10563 a = torch.tensor([1.0]).as_strided((1,), (2,), 0) 10564 b = torch.tensor([1.0]).as_strided((1,), (1,), 0) 10565 10566 with fwAD.dual_level(): 10567 dual_input = fwAD.make_dual(a, b) 10568 dual_input[1:] 10569 10570 # The following test functions want to ensure all the following behaviors: 10571 # 
- Ensure that default level system in the python binding works 10572 # - Ensure that only level 0 exists and nesting is properly disabled 10573 # - Ensure that printing works fine 10574 # - Ensure that basic packing/unpacking works 10575 # - Ensure that advanced packing/unpacking works 10576 # - For memory / version counter share 10577 # - For backward AD (regular ops) 10578 # - Ensure that view + inplace for both modes work fine 10579 # - Ensure we do proper cleanup on exit of a level 10580 10581 def test_default_level(self): 10582 foo = torch.rand(2) 10583 bar = torch.rand(2) 10584 10585 with fwAD.dual_level(): 10586 baz = fwAD.make_dual(foo, bar) 10587 baz_primal, baz_tangent = fwAD.unpack_dual(baz) 10588 self.assertEqual(baz_primal, foo) 10589 # We don't actually need to enforce that these two are the exact same python 10590 # object, feel free to relax in the future 10591 self.assertIs(baz_tangent, bar) 10592 10593 baz_primal, baz_tangent = fwAD.unpack_dual(baz) 10594 self.assertEqual(baz_primal, foo) 10595 self.assertEqual(baz_tangent, None) 10596 10597 def test_fwd_grad_enabled(self): 10598 # Tests some private helper functions to enable/disable fwd grad mode 10599 enabled = fwAD._is_fwd_grad_enabled() 10600 self.assertTrue(enabled) 10601 10602 try: 10603 torch._C._set_fwd_grad_enabled(False) 10604 enabled = fwAD._is_fwd_grad_enabled() 10605 self.assertFalse(enabled) 10606 finally: 10607 torch._C._set_fwd_grad_enabled(True) 10608 10609 enabled = fwAD._is_fwd_grad_enabled() 10610 self.assertTrue(enabled) 10611 10612 def test_set_fwd_grad_enabled(self): 10613 # Tests a private helper function 10614 try: 10615 torch._C._set_fwd_grad_enabled(False) 10616 enabled = fwAD._is_fwd_grad_enabled() 10617 self.assertFalse(enabled) 10618 10619 with fwAD._set_fwd_grad_enabled(True): 10620 enabled = fwAD._is_fwd_grad_enabled() 10621 self.assertTrue(enabled) 10622 10623 enabled = fwAD._is_fwd_grad_enabled() 10624 self.assertFalse(enabled) 10625 finally: 10626 torch._C._set_fwd_grad_enabled(True) 10627 10628 def test_nested_level(self): 10629 with fwAD.dual_level() as level: 10630 # For now only level 0 exists 10631 self.assertEqual(level, 0) 10632 10633 with fwAD.dual_level(): 10634 with self.assertRaisesRegex( 10635 RuntimeError, "Nested forward mode AD is not supported at the moment" 10636 ): 10637 nest_level = fwAD.enter_dual_level() 10638 10639 def test_set_fw_grad_having_own_fw_grad_at_same_level(self): 10640 foo = torch.rand(2) 10641 bar = torch.rand(2) 10642 baz = torch.rand(2) 10643 10644 with fwAD.dual_level(): 10645 dual = fwAD.make_dual(foo, bar) 10646 with self.assertRaisesRegex( 10647 RuntimeError, "has a forward gradient at the same level" 10648 ): 10649 fwAD.make_dual(baz, dual) 10650 10651 def test_codegen_ignores_undefined_outputs(self): 10652 # This test checks that codegen silently ignores undefined outputs 10653 # Below, grad_input is specified as False in grad_output_mask, so 10654 # convolution backward will return a undefined tensor in that position. 
10655 # Note that for this test to work we need to make sure either grad_output 10656 # or weight to be a dual tensor, so grad_input requires forward grad 10657 weight = torch.randn(6, 1, 30, 30) 10658 inp = torch.rand((1, 1, 32, 32)) 10659 out = torch.nn.functional.conv2d(inp, weight) 10660 grad_out = torch.ones_like(out) 10661 10662 with fwAD.dual_level(): 10663 dual_weight = fwAD.make_dual(weight, torch.ones_like(weight)) 10664 grad_input, _, _ = torch.ops.aten.convolution_backward( 10665 grad_out, 10666 inp, 10667 dual_weight, 10668 (0,), 10669 (1, 1), 10670 (0, 0), 10671 (1, 1), 10672 False, 10673 (0, 0), 10674 1, 10675 (False, True, False), 10676 ) 10677 self.assertIsNone(grad_input) 10678 10679 def test_make_dual_inference_tensor_in_inference_mode(self): 10680 with torch.inference_mode(): 10681 foo = torch.rand(2) 10682 bar = torch.rand(2) 10683 foo_copy = foo.clone() 10684 10685 with fwAD.dual_level(): 10686 dual = fwAD.make_dual(foo, bar) 10687 self.assertFalse(dual._is_view()) 10688 10689 dual += 1 10690 self.assertFalse(torch.allclose(foo, foo_copy)) 10691 10692 def test_make_dual_torch_dispatch(self): 10693 counter = [0] 10694 10695 class MySubclass(torch.Tensor): 10696 def __new__(cls, data=None): 10697 return torch.Tensor._make_subclass(cls, data) 10698 10699 @classmethod 10700 def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 10701 if func.overloadpacket == torch.ops.aten.alias: 10702 counter[0] += 1 10703 10704 # Make sure we can re-enable autograd here 10705 with torch.overrides.enable_reentrant_dispatch(): 10706 foo = torch.rand(1, requires_grad=True) 10707 self.assertIsNotNone(foo.exp().grad_fn) 10708 10709 with no_dispatch(): 10710 return func(*args, **kwargs) 10711 10712 a = torch.tensor(1.0) 10713 s = MySubclass(a) 10714 10715 with fwAD.dual_level(): 10716 # Only the primal has "alias" called on it 10717 fwAD.make_dual(s, torch.rand_like(s)) 10718 self.assertEqual(counter[0], 1) 10719 fwAD.make_dual(torch.rand_like(s), s) 10720 self.assertEqual(counter[0], 1) 10721 10722 def test_make_dual_forbid_integral_dtype(self): 10723 primal_f = torch.ones(2, 2, dtype=torch.float) 10724 primal_l = torch.ones(2, 2, dtype=torch.long) 10725 10726 tangent_f = torch.ones(2, 2, dtype=torch.float) 10727 tangent_l = torch.ones(2, 2, dtype=torch.long) 10728 10729 with fwAD.dual_level(): 10730 # Float Primal and Long Tangent 10731 with self.assertRaisesRegex( 10732 ValueError, "Expected tangent to be floating point or complex" 10733 ): 10734 fwAD.make_dual(primal_f, tangent_l) 10735 10736 # Long Primal and Long Tangent 10737 with self.assertRaisesRegex( 10738 ValueError, "Expected primal to be floating point or complex" 10739 ): 10740 fwAD.make_dual(primal_l, tangent_l) 10741 10742 # Long Primal and Float Tangent 10743 with self.assertRaisesRegex( 10744 ValueError, "Expected primal to be floating point or complex" 10745 ): 10746 fwAD.make_dual(primal_l, tangent_f) 10747 10748 def test_print(self): 10749 with fwAD.dual_level() as level: 10750 a = torch.rand(3) 10751 self.assertFalse("tangent=" in str(a)) 10752 10753 b = fwAD.make_dual(a, torch.rand(3)) 10754 self.assertFalse("tangent=" in str(a)) 10755 self.assertTrue("tangent=" in str(b)) 10756 10757 b_primal, b_tangent = fwAD.unpack_dual(b) 10758 self.assertFalse("tangent=" in str(b_primal)) 10759 self.assertFalse("tangent=" in str(b_tangent)) 10760 10761 def test_basic_packing_unpacking(self): 10762 foo = torch.rand(2) 10763 bar = torch.rand(2) 10764 10765 with fwAD.dual_level(): 10766 baz = fwAD.make_dual(foo, bar) 10767 
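            # make_dual attaches `bar` to `foo` as its forward gradient (tangent) for
            # the duration of this dual level; unpack_dual returns a (primal, tangent)
            # named tuple, and the tangent is expected to be the exact tensor object
            # that was passed in (checked with assertIs below).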
baz_primal, baz_tangent = fwAD.unpack_dual(baz) 10768 self.assertEqual(baz_primal, foo) 10769 self.assertIs(baz_tangent, bar) 10770 10771 # Check unpacked dual is returned as a named tuple 10772 # NB: Every invocation of unpack_dual returns a new tensor view 10773 self.assertIsNot(baz_primal, fwAD.unpack_dual(baz).primal) 10774 self.assertEqual(baz_primal, fwAD.unpack_dual(baz).primal) 10775 self.assertIs(baz_tangent, fwAD.unpack_dual(baz).tangent) 10776 10777 # Check that packing/unpacking did not change the input 10778 foo_primal, foo_tangent = fwAD.unpack_dual(foo) 10779 self.assertEqual(foo_primal, foo) 10780 self.assertIsNone(foo_tangent) 10781 10782 def test_advanced_packing_unpacking(self): 10783 foo = torch.rand(2) 10784 bar = torch.ones(2) 10785 10786 # Memory and version counter check 10787 with fwAD.dual_level(): 10788 dual = fwAD.make_dual(foo, bar) 10789 10790 # Ensure that they are sharing memory and version counter 10791 self.assertEqual(dual.storage().data_ptr(), foo.storage().data_ptr()) 10792 10793 # Ensure we properly share the version counter 10794 self.assertEqual(foo._version, dual._version) 10795 foo.add_(1) 10796 self.assertEqual(foo._version, dual._version) 10797 10798 # Unpacking should only create aliases as well 10799 dual_primal, dual_tangent = fwAD.unpack_dual(dual) 10800 self.assertEqual(dual_primal.storage().data_ptr(), foo.storage().data_ptr()) 10801 self.assertEqual( 10802 dual_tangent.storage().data_ptr(), bar.storage().data_ptr() 10803 ) 10804 # And the tangent is actually re-used as-is so it is still the same Tensor 10805 self.assertIs(dual_tangent, bar) 10806 10807 # Ensure we properly share the version counter 10808 self.assertEqual(foo._version, dual_primal._version) 10809 foo.add_(1) 10810 self.assertEqual(foo._version, dual_primal._version) 10811 self.assertEqual(bar._version, dual_tangent._version) 10812 bar.add_(1) 10813 self.assertEqual(bar._version, dual_tangent._version) 10814 10815 # backward mode check 10816 with fwAD.dual_level(): 10817 foo.requires_grad_() 10818 bar.requires_grad_() 10819 10820 # Check that backward gradients properly propagates through packing/unpacking 10821 dual = fwAD.make_dual(foo, bar) 10822 p, t = fwAD.unpack_dual(dual) 10823 10824 gfoo, gbar = torch.autograd.grad( 10825 p.sum(), (foo, bar), retain_graph=True, allow_unused=True 10826 ) 10827 self.assertEqual(gfoo, torch.ones_like(foo)) 10828 self.assertIsNone(gbar) 10829 10830 gfoo, gbar = torch.autograd.grad( 10831 t.sum(), (foo, bar), retain_graph=True, allow_unused=True 10832 ) 10833 self.assertIsNone(gfoo) 10834 self.assertEqual(gbar, torch.ones_like(bar)) 10835 10836 # Check that forward gradients are impacted by detach() 10837 detached_dual = dual.detach() 10838 out = detached_dual * 2 10839 p, t = fwAD.unpack_dual(out) 10840 self.assertFalse(p.requires_grad) 10841 self.assertEqual(p, foo * 2) 10842 self.assertIsNone(t) 10843 10844 # Check that forward gradients are not impacted by no_grad 10845 with torch.no_grad(): 10846 out = dual * 3 10847 p, t = fwAD.unpack_dual(out) 10848 self.assertFalse(p.requires_grad) 10849 self.assertFalse(t.requires_grad) 10850 self.assertEqual(p, foo * 3) 10851 self.assertEqual(t, bar * 3) 10852 10853 # Check that forward gradients are not impacted by inplace detach 10854 dual = dual.clone() 10855 dual.detach_() 10856 out = dual * 2 10857 p, t = fwAD.unpack_dual(out) 10858 self.assertFalse(p.requires_grad) 10859 self.assertEqual(p, foo * 2) 10860 self.assertIsNone(t) 10861 10862 def 
test_view_inplace_non_differentiable_views(self): 10863 original_foo = torch.rand(2, dtype=torch.double) 10864 original_bar = torch.ones(2, dtype=torch.double) 10865 10866 # Do clones to be able to compare the values updated inplace 10867 # with the original content of these Tensors 10868 foo = original_foo.clone() 10869 bar = original_bar.clone() 10870 10871 with fwAD.dual_level(): 10872 # Note that in this test, we use "update" to mean computing the right tangent for the dual 10873 # All the inplace operations here are expected to update the primal value of the Tensors but 10874 # not always their tangents. 10875 # Also all mentions of "non differentiable view" here means non forward differentiable view 10876 # unless specified otherwise. 10877 # See note [Forward Grad View/inplace] for more details on how these views work. 10878 10879 # Check that inplace ops do not update non-differentiable views 10880 # Non differentiable view 10881 dual = fwAD.make_dual(foo, bar) 10882 dual *= 2 10883 # Check that non differentiable view's tangent was not updated 10884 self.assertIsNone(fwAD.unpack_dual(foo)[1]) 10885 # Check that the computed result is correct 10886 self.assertEqual(bar, original_bar * 2) 10887 self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2) 10888 self.assertEqual(foo, original_foo * 2) 10889 self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 2) 10890 # Other non differentiable view 10891 dual_primal, dual_tangent = fwAD.unpack_dual(dual) 10892 self.assertIsNone(fwAD.unpack_dual(dual_primal)[1]) 10893 self.assertIsNone(fwAD.unpack_dual(dual_tangent)[1]) 10894 dual_primal *= 2 10895 # Ensure dual's tangent did not change 10896 self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4) 10897 self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2) 10898 dual_tangent *= 2 10899 # Ensure dual's primal did not change 10900 self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4) 10901 self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 4) 10902 10903 def test_view_inplace_differentiable_views(self): 10904 original_foo = torch.rand(2) 10905 original_bar = torch.ones(2) 10906 10907 # Do clones to be able to compare the values updated inplace 10908 # with the original content of these Tensors 10909 foo = original_foo.clone() 10910 bar = original_bar.clone() 10911 10912 with fwAD.dual_level(): 10913 # Check that inplace ops do update differentiable view but stop at non differentiable ones 10914 # A non differentiable view 10915 dual = fwAD.make_dual(foo, bar) 10916 # A differentiable view 10917 view = dual.narrow(0, 0, 1) 10918 view *= 2 10919 # Check that non differentiable view was not updated 10920 self.assertIsNone(fwAD.unpack_dual(foo)[1]) 10921 # Check that differentiable view was updated 10922 self.assertEqual(fwAD.unpack_dual(dual)[1], torch.tensor([2.0, 1.0])) 10923 self.assertEqual(fwAD.unpack_dual(view)[1], torch.tensor([2.0])) 10924 10925 # Check that we track differentiable view even for Tensors that are not dual 10926 baz = torch.rand(2) 10927 baz += dual 10928 self.assertEqual(fwAD.unpack_dual(baz)[1], fwAD.unpack_dual(dual)[1]) 10929 # Updates on view should as well 10930 baz = torch.rand(2) 10931 baz[0] = dual[0] 10932 self.assertEqual(fwAD.unpack_dual(baz)[1][0], fwAD.unpack_dual(dual)[1][0]) 10933 # Unused values get a gradient of 0 10934 self.assertEqual(fwAD.unpack_dual(baz)[1][1], 0.0) 10935 10936 # Check that forward non-differentiable views do prevent gradient update 10937 baz = torch.rand(2) 10938 view = baz.detach() 10939 
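            # detach() yields a view that is non-differentiable for forward AD too,
            # so the in-place update below must not propagate a tangent back to `baz`.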
view += dual 10940 self.assertIsNone(fwAD.unpack_dual(baz)[1]) 10941 10942 def test_view_inplace_always_creates_a_view(self): 10943 # See https://github.com/pytorch/pytorch/issues/67800 10944 # The codepath may depend on the op. At the time writing, when self is not a dual tensor 10945 # the resulting forward grad for self for... 10946 # - add_ has the same layout as self 10947 # - mul_ has the same layout as other 10948 # This is kind of fragile because the above depends on how the forward grad expression 10949 # is written. For add and mul at least, the output inherits the layout of LHS. 10950 # We want to handle at least these two cases. 10951 inplace_binary_ops = ( # Add more to this list? 10952 lambda x, y: x.add_(y), 10953 lambda x, y: x.mul_(y), 10954 lambda x, y: x.copy_(y), 10955 ) 10956 10957 for inplace_binary_op in inplace_binary_ops: 10958 base = torch.randn(2, 2) 10959 view = base.transpose(0, 1) 10960 10961 primal = torch.randn(2, 2) 10962 tangent = torch.randn(2, 2) 10963 10964 with fwAD.dual_level(): 10965 dual = fwAD.make_dual(primal, tangent) 10966 inplace_binary_op(view, dual) 10967 10968 # Verify that a view relationship is created for both the primal and tangent 10969 p, t = fwAD.unpack_dual(base) 10970 p_clone = p.clone() 10971 t_clone = t.clone() 10972 view *= 2 10973 p, t = fwAD.unpack_dual(base) 10974 10975 self.assertTrue(torch.allclose(p_clone * 2, p)) 10976 self.assertTrue(torch.allclose(t_clone * 2, t)) 10977 10978 def test_grad_cleanup(self): 10979 foo = torch.rand(2) 10980 bar = torch.rand(2) 10981 baz = torch.rand(2) 10982 10983 with fwAD.dual_level(): 10984 dual = fwAD.make_dual(foo, bar) 10985 self.assertIsNone(fwAD.unpack_dual(foo)[1]) 10986 self.assertIs(fwAD.unpack_dual(dual)[1], bar) 10987 10988 self.assertIsNone(fwAD.unpack_dual(dual)[1]) 10989 10990 with fwAD.dual_level(): 10991 self.assertIsNone(fwAD.unpack_dual(foo)[1]) 10992 new_dual = fwAD.make_dual(foo, baz) 10993 10994 dual_primal, dual_tangent = fwAD.unpack_dual(dual) 10995 new_dual_primal, new_dual_tangent = fwAD.unpack_dual(new_dual) 10996 self.assertEqual(dual_primal, new_dual_primal) 10997 self.assertIsNone(dual_tangent) 10998 self.assertEqual(new_dual_tangent, baz) 10999 11000 def test_detach_view_tracking(self): 11001 # Default detach is both forward and backward non-differentiable 11002 foo = torch.rand(2) 11003 foo_weak = torch._C._WeakTensorRef(foo) 11004 11005 out = foo.detach() 11006 11007 del foo 11008 self.assertTrue(foo_weak.expired()) 11009 11010 def test_out_variant(self): 11011 with fwAD.dual_level(): 11012 foo = fwAD.make_dual(torch.rand(2), torch.rand(2)) 11013 bar = torch.rand(2) 11014 11015 with self.assertRaisesRegex(RuntimeError, "out= function"): 11016 torch.add(bar, bar, out=foo) 11017 11018 with self.assertRaisesRegex(RuntimeError, "out= function"): 11019 torch.add(foo, bar, out=bar) 11020 11021 def test_non_differentiable(self): 11022 with fwAD.dual_level(): 11023 foo = fwAD.make_dual(torch.rand(2), torch.rand(2)) 11024 bar = torch.rand(2) 11025 11026 # No differentiable outputs, shouldn't error 11027 eq = foo == bar 11028 11029 # Inplace 11030 foo.eq_(bar) 11031 11032 def test_create_new_zeros_with_same_meta(self): 11033 new_zeroes_fn = torch.ops.aten._new_zeros_with_same_feature_meta 11034 11035 def check(a, b): 11036 def assert_same_meta(t, target): 11037 for num_bdim in range(t.dim()): 11038 result = new_zeroes_fn(t, target, self_num_batch_dims=num_bdim) 11039 11040 self.assertEqual(result.dim(), target.dim() + num_bdim) 11041 11042 # Check size/strides match 
for feature dims only 11043 for i in range(num_bdim, result.dim()): 11044 self.assertEqual(result.size()[i], target.size()[i - num_bdim]) 11045 self.assertEqual( 11046 result.stride()[i], target.stride()[i - num_bdim] 11047 ) 11048 11049 # Check that we generate strides reasonably 11050 if target.is_contiguous(): 11051 self.assertTrue(result.is_contiguous()) 11052 11053 self.assertEqual(result.storage_offset(), target.storage_offset()) 11054 11055 prod_of_t_bdims = reduce(operator.mul, t.size()[:num_bdim], 1) 11056 self.assertEqual( 11057 len(result.storage()), len(target.storage()) * prod_of_t_bdims 11058 ) 11059 11060 # TensorOptions is same 11061 self.assertEqual(result.dtype, target.dtype) 11062 11063 assert_same_meta(a, b) 11064 assert_same_meta(b, a) 11065 11066 a = torch.randn(5, dtype=torch.float) 11067 b = torch.randn(2, 3, 4, dtype=torch.double) 11068 check(a, b) 11069 11070 # non-contiguous case 11071 a = torch.randn(2, 3, 4).transpose(0, 1).contiguous().transpose(0, 1) 11072 b = torch.randn(2, 3, 4) 11073 check(a, b) 11074 11075 a = torch.randn(5).narrow(0, 1, 2) 11076 b = torch.randn(2) 11077 check(a, b) 11078 11079 # tensor is not a view, but still does not index entirety of storage 11080 a = torch.randn(5).resize_(4) 11081 b = torch.randn(4) 11082 check(a, b) 11083 11084 # Zero-numel tensors 11085 a = torch.randn(1, 0, 2) 11086 b = torch.randn(1, 2) 11087 check(a, b) 11088 11089 # Scalar tensor 11090 a = torch.tensor(1.0) 11091 b = torch.randn(1, 2) 11092 check(a, b) 11093 11094 def test_backward_graph_destruction(self): 11095 def fn(): 11096 a = torch.rand(10, requires_grad=True) 11097 11098 da = fwAD.make_dual(torch.rand_like(a), a) 11099 11100 # Create an object with a c++ cycle as: 11101 # db -> AutogradMeta -> ForwardGrad -> db's grad 11102 # db's grad -> AutogradMeta -> MulBackward 11103 # MulBackward -> SavedVariable -> db 11104 db = da.exp() 11105 11106 with fwAD.dual_level(): 11107 fn() 11108 # This test make sure that we don't deadlock on exit of this 11109 # context manager. If you do, there is something wrong with the 11110 # locking of the forward ad level most likely 11111 11112 11113# Generic device type autograd tests. 
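# The methods below take an explicit `device` argument; per-device test classes
# (CPU, CUDA, ...) are generated from this template via instantiate_device_type_tests,
# so the tests should not hard-code a particular device.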
11114class TestAutogradDeviceType(TestCase): 11115 def test_min_max_median_backprops_to_all_values(self, device): 11116 for f in [torch.min, torch.max, torch.median, torch.nanmedian]: 11117 x1 = torch.tensor( 11118 [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True 11119 ) 11120 x2 = torch.tensor( 11121 [float("nan"), float("nan"), float("nan")], requires_grad=True 11122 ) 11123 for x in [x1, x2]: 11124 y = f(x) 11125 y.backward() 11126 self.assertEqual(x.grad.sum(), 1.0) 11127 self.assertEqual((x.grad == 1 / 3).sum(), 3) 11128 11129 def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device): 11130 # tests that gradients are evenly distributed when there are multiple max/min values 11131 # tested here instead of adding a SampleInput as the backward for this case is non-differentiable for gradgrad 11132 # as is the case for test_min_max_median_backprops_to_all_values above 11133 fns = (torch.scatter_reduce, torch.index_reduce) 11134 reduces = ("amin", "amax") 11135 for fn, reduction in product(fns, reduces): 11136 input = torch.randn( 11137 (2, 3), device=device, dtype=torch.float64, requires_grad=True 11138 ) 11139 src = input.clone().detach_().requires_grad_(True) 11140 idx = torch.arange(2).to(dtype=torch.long, device=device) 11141 if fn == torch.scatter_reduce: 11142 idx = idx.unsqueeze(-1).expand((2, 3)) 11143 11144 gradcheck(fn, (input, 0, idx, src, reduction), check_batched_grad=False) 11145 11146 def test_scatter_index_reduce_prod_gradgrad_error(self, device): 11147 # test that double backward raises an error for the case where 2 zeros in src 11148 # are scattered to the same position in self 11149 input = torch.tensor( 11150 [1.0], device=device, dtype=torch.float64, requires_grad=True 11151 ) 11152 src = torch.tensor( 11153 [0.0, 0.0], device=device, dtype=torch.float64, requires_grad=True 11154 ) 11155 idx = torch.tensor([0, 0], device=device, dtype=torch.long) 11156 11157 for fn in (torch.scatter_reduce, torch.index_reduce): 11158 # check that this case passes on gradcheck 11159 gradcheck(fn, (input, 0, idx, src, "prod"), check_batched_grad=False) 11160 with self.assertRaisesRegex( 11161 RuntimeError, "Double backward is unsupported for" 11162 ): 11163 gradgradcheck(fn, (input, 0, idx, src, "prod")) 11164 11165 @skipIfMps # the test doesn't work on MPS as double types are not supported 11166 def test_parameter_resize(self, device): 11167 asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device)) 11168 11169 for i in range(2): 11170 with torch.no_grad(): 11171 asd.set_(asd[1:]) 11172 asd.grad = None 11173 11174 m = torch.cat((asd, asd)) 11175 m.sum().backward() 11176 11177 @skipIfMps # the test doesn't work on MPS as double types are not supported 11178 @dtypes(torch.double, torch.cdouble) 11179 def test_sparse_ctor_getter_backward(self, device, dtype): 11180 # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test 11181 def _test(size, sparse_dim, nnz, device): 11182 v_size = [nnz] + list(size[sparse_dim:]) 11183 i = torch.rand(sparse_dim, nnz) 11184 i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i)) 11185 i = i.to(torch.long) 11186 11187 inp = torch.randn( 11188 v_size, dtype=torch.double, device=device, requires_grad=True 11189 ) 11190 other = self.genSparseTensor( 11191 size, sparse_dim, nnz, is_uncoalesced=True, device=device, dtype=dtype 11192 )[0] 11193 11194 def fn(v): 11195 x = torch.sparse_coo_tensor(i, v, size, dtype=dtype, device=device) 11196 y = (x + other).coalesce() 11197 yv = 
y.values() 11198 new_v = yv.tanh() 11199 z = torch.sparse_coo_tensor(y.indices(), new_v, y.size()) 11200 return z.coalesce().values() 11201 11202 gradcheck(fn, (inp,), check_batched_grad=False) 11203 # FIXME: make gradgradcheck work. 11204 # gradgradcheck(fn, (inp,), check_batched_grad=False) 11205 11206 # assert that _values is non-differentiable 11207 with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"): 11208 other.detach().requires_grad_()._values().backward( 11209 torch.ones_like(other._values()) 11210 ) 11211 11212 for empty_i, empty_v, empty_nnz in product([True, False], repeat=3): 11213 sparse_size = [] if empty_i else [2, 1] 11214 dense_size = [1, 0, 2] if empty_v else [1, 2] 11215 nnz = 0 if empty_nnz else 5 11216 _test(sparse_size + dense_size, len(sparse_size), nnz, device) 11217 11218 @skipMeta 11219 @skipIfMps 11220 @dtypes(torch.double, torch.cdouble) 11221 def test_sparse_backward(self, device, dtype): 11222 class FixedGradientFunction(Function): 11223 @staticmethod 11224 def forward(ctx, x, grad_x): 11225 ctx.save_for_backward(grad_x) 11226 return x 11227 11228 @staticmethod 11229 def backward(ctx, grad_x): 11230 (saved_grad_x,) = ctx.saved_tensors 11231 return saved_grad_x, None 11232 11233 size = torch.Size([6, 3, 2]) 11234 i1 = torch.tensor([[0, 3, 4], [0, 2, 2]], dtype=torch.long) 11235 v1 = make_tensor([3, 2], dtype=dtype, device=device) 11236 sparse_grad1 = torch.sparse_coo_tensor(i1, v1, size, dtype=dtype, device=device) 11237 i2 = torch.tensor([[0, 1, 3, 4], [0, 1, 2, 2]], dtype=torch.long) 11238 v2 = make_tensor([4, 2], dtype=dtype, device=device) 11239 sparse_grad2 = torch.sparse_coo_tensor(i2, v2, size, dtype=dtype, device=device) 11240 dense_grad = torch.rand(size, device=device, dtype=dtype) 11241 fn = FixedGradientFunction 11242 11243 # sparse first 11244 x = torch.randn(size, dtype=dtype, device=device, requires_grad=True) 11245 ( 11246 fn.apply(x, sparse_grad1) 11247 + fn.apply(x, dense_grad) 11248 + fn.apply(x, sparse_grad2) 11249 ).sum().abs().backward() 11250 self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2) 11251 # dense first 11252 x = torch.randn(size, dtype=dtype, device=device, requires_grad=True) 11253 ( 11254 fn.apply(x, dense_grad) 11255 + fn.apply(x, sparse_grad1) 11256 + fn.apply(x, sparse_grad2) 11257 ).sum().abs().backward() 11258 self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2) 11259 # sparse only 11260 x = torch.randn(size, dtype=dtype, device=device, requires_grad=True) 11261 (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward() 11262 self.assertEqual(x.grad, sparse_grad1 + sparse_grad2) 11263 11264 @skipIfMps 11265 def test_sparse_mask_autograd(self, device): 11266 tensor = torch.randn(3, requires_grad=True, device=device) 11267 mask = torch.ones(3, device=device) 11268 mask[1] = 0 11269 mask = mask.to_sparse() 11270 converted = tensor.sparse_mask(mask).to_dense() 11271 converted.sum().backward() 11272 self.assertEqual(tensor.grad, mask.to_dense()) 11273 11274 @skipIfMps # the test doesn't work on MPS as double types are not supported 11275 def test_pyscalar_conversions(self, device): 11276 def _test_pyscalar_conversions(t, integral_conv): 11277 # integral -> integral 11278 l = t(torch.zeros(1, 1, 1, dtype=torch.long)) 11279 pyscalar = -12345 11280 l[0] = pyscalar 11281 self.assertEqual(integral_conv(l), pyscalar) 11282 11283 # floating point -> floating point 11284 f = Variable(t(torch.randn(1, 1, dtype=torch.double))) 11285 pyscalar = -12345.1 11286 f[0] = 
pyscalar 11287 self.assertEqual(float(f), pyscalar) 11288 f[0] = nan 11289 self.assertTrue(math.isnan(float(f))) 11290 f[0] = inf 11291 self.assertEqual(float(f), inf) 11292 f[0] = -inf 11293 self.assertEqual(float(f), -inf) 11294 11295 # integral -> floating point 11296 # check we can convert something that loses precision 11297 pyscalar = 1234567890123456789 11298 self.assertNotEqual(pyscalar, integral_conv(float(pyscalar))) 11299 l[0] = pyscalar 11300 self.assertEqual(float(l), float(pyscalar)) 11301 11302 # floating point -> integral 11303 f[0] = nan 11304 self.assertRaises(ValueError, lambda: integral_conv(f[0])) 11305 f[0] = inf 11306 self.assertRaises(OverflowError, lambda: integral_conv(f[0])) 11307 f[0] = -inf 11308 self.assertRaises(OverflowError, lambda: integral_conv(f[0])) 11309 f[0] = sys.float_info.max 11310 self.assertEqual(integral_conv(f), sys.float_info.max) 11311 11312 # bool, nonzero 11313 def test_nonzero(tensor, value, expected): 11314 tensor[0] = value 11315 self.assertEqual(expected, bool(tensor)) 11316 self.assertEqual(expected, True if tensor else False) 11317 11318 test_nonzero(l, 0, False) 11319 test_nonzero(l, -2, True) 11320 test_nonzero(f, 0.0, False) 11321 test_nonzero(f, sys.float_info.min, True) 11322 test_nonzero(f, nan, bool(nan)) 11323 test_nonzero(f, inf, bool(inf)) 11324 test_nonzero(f, -inf, bool(-inf)) 11325 11326 _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x)) 11327 11328 @dtypesIfMPS(torch.float32) 11329 @dtypesIfCUDA( 11330 torch.half, 11331 torch.float, 11332 torch.double, 11333 torch.int8, 11334 torch.int16, 11335 torch.int32, 11336 torch.int64, 11337 ) 11338 @dtypes( 11339 torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64 11340 ) 11341 def test_set_requires_grad_only_for_floats(self, device, dtype): 11342 def f1(): 11343 a = torch.ones(1, dtype=dtype, device=device) 11344 a.requires_grad_() 11345 11346 def f2(): 11347 a = torch.ones(1, dtype=dtype, device=device) 11348 a.requires_grad = True 11349 11350 def f3(): 11351 torch.ones(1, dtype=dtype, device=device, requires_grad=True) 11352 11353 a = torch.ones(1, dtype=dtype, device=device) 11354 a.requires_grad = False # should always work 11355 a.requires_grad_(False) 11356 11357 for f in [f1, f2, f3]: 11358 if dtype.is_floating_point: 11359 f() 11360 else: 11361 with self.assertRaisesRegex( 11362 RuntimeError, 11363 "floating point", 11364 msg=f"dt: {a.dtype} device: {a.device}", 11365 ): 11366 f() 11367 11368 @onlyCUDA 11369 def test_advanced_indexing_backwards_large(self, device): 11370 # See https://github.com/pytorch/pytorch/issues/22843 11371 n = 1 << 16 11372 x = torch.rand(n, 1, device=device, requires_grad=True) 11373 a = x[:, [0]] 11374 a.sum().backward() 11375 self.assertEqual(x.grad, torch.ones(n, 1, device=device)) 11376 11377 def test_advanced_indexing_backwards_memory_format(self, device): 11378 # See https://github.com/pytorch/pytorch/issues/36956 11379 shape = (2, 8, 1, 2) 11380 i = torch.randint(1, shape, device=device).contiguous( 11381 memory_format=torch.channels_last 11382 ) 11383 x = torch.randn(shape, requires_grad=True, device=device) 11384 x[i].sum().backward() 11385 11386 def _test_reentrant_parent_error_on_cpu(self, device): 11387 t1 = torch.rand([3, 3], requires_grad=True) 11388 t2 = torch.rand([3, 3], device=device, requires_grad=True) 11389 t3 = torch.rand([3, 3], device=device, requires_grad=True) 11390 11391 # Parent graph cpu graph. 
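        # The CPU branch (t1 -> t4 -> t5) raises via SimulateBackwardError during
        # backward, while the GPU branch re-enters autograd with a much longer
        # graph; the assertions at the end check that the error aborts the whole
        # GraphTask before any .grad is accumulated.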
11392 t4 = t1 * t1 11393 t5 = TestAutograd.SimulateBackwardError.apply(t4) 11394 11395 # Child gpu graph (much longer than parent graph). 11396 prev = t2 * t2 11397 for i in range(10): 11398 prev = prev * t2 11399 reentrant_root = prev 11400 11401 class ReentrantFunc(Function): 11402 @staticmethod 11403 def forward(ctx, inp): 11404 return inp.clone() 11405 11406 @staticmethod 11407 def backward(ctx, grad): 11408 # Reentrant backward in child will take much longer. 11409 reentrant_root.backward() 11410 return grad 11411 11412 # Parent gpu graph. 11413 t6 = ReentrantFunc.apply(t3) 11414 t7 = t6 * t6 11415 11416 # Parent graph will error out first, while child graph will continue executing. 11417 with self.assertRaisesRegex(Exception, "Simulate error"): 11418 torch.autograd.backward([t5.sum(), t7.sum()]) 11419 11420 # No grads should be accumulated since child graph will stop execution 11421 # after parent receives error. 11422 self.assertIsNone(t2.grad) 11423 self.assertIsNone(t1.grad) 11424 self.assertIsNone(t3.grad) 11425 11426 @onlyCUDA 11427 def test_reentrant_parent_error_on_cpu(self, device): 11428 def _get_cuda_memory_usage(): 11429 # we don't need CUDA synchronize because the statistics are not tracked at 11430 # actual freeing, but at when marking the block as free. 11431 num_devices = torch.cuda.device_count() 11432 gc.collect() 11433 return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices)) 11434 11435 before = _get_cuda_memory_usage() 11436 11437 # Run as separate function so that gc can clean up everything when we 11438 # check for memory usage. 11439 self._test_reentrant_parent_error_on_cpu(device) 11440 11441 # Wait for autograd thread to cleanup failed tasks. 11442 after = _get_cuda_memory_usage() 11443 start = time.time() 11444 while before != after and time.time() - start < 30: 11445 time.sleep(0.1) 11446 after = _get_cuda_memory_usage() 11447 11448 self.assertEqual(before, after) 11449 11450 @skipIfMps # the test doesn't work on MPS 11451 # TODO: see if these tests can be ported to OpInfos or moved to where's test suite 11452 def test_where_functional(self, device): 11453 x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True) 11454 y = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True) 11455 cond = mask_not_all_zeros((5, 5)).to(device=device) 11456 11457 def where(cond, x, y): 11458 return torch.where(cond, x, y) 11459 11460 gradcheck(where, [cond, x, y], raise_exception=True) 11461 gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)]) 11462 11463 x = torch.randn(5, 1, 5, dtype=torch.double, device=device, requires_grad=True) 11464 y = torch.randn(5, 5, 1, dtype=torch.double, device=device, requires_grad=True) 11465 gradcheck(where, [cond, x, y], raise_exception=True) 11466 gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)]) 11467 11468 @skipIfMps # the test doesn't work on MPS 11469 def test_where_scalar(self, device): 11470 x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True) 11471 scalar = 4.0 11472 cond = mask_not_all_zeros((5, 5)).to(device=device) 11473 11474 def where_scalar_first(cond, x): 11475 return torch.where(cond, scalar, x) 11476 11477 def where_scalar_second(cond, x): 11478 return torch.where(cond, x, scalar) 11479 11480 gradcheck(where_scalar_first, (cond, x)) 11481 gradgradcheck(where_scalar_first, (cond, x)) 11482 11483 gradcheck(where_scalar_second, (cond, x)) 11484 gradgradcheck(where_scalar_second, (cond, x)) 11485 11486 @onlyCUDA 
11487 def test_free_unneeded_tensor(self, device): 11488 x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True) 11489 m = torch.randn(1, 3, 1, 1, device=device) 11490 11491 z = x.sum() 11492 base_mem = torch.cuda.memory_allocated() 11493 z = ((x + 2) * m).sum() 11494 end_mem = torch.cuda.memory_allocated() 11495 11496 # In the end the memory usage should remain equal, because neither of 11497 # (x + 2) and ((x + 2) * m) should be kept alive for backward, while the 11498 # previous allocation of z had the same size as the current one. 11499 self.assertEqual(base_mem, end_mem) 11500 11501 @onlyCUDA 11502 def test_pin_memory(self, device): 11503 x = torch.randn(2, 2, dtype=torch.double, requires_grad=True) 11504 self.assertEqual(x, x.pin_memory()) 11505 self.assertIsNot(x, x.pin_memory()) 11506 self.assertTrue(x.pin_memory().requires_grad) 11507 gradcheck(lambda x: x.pin_memory(), [x]) 11508 gradgradcheck(lambda x: x.pin_memory(), [x]) 11509 11510 @onlyCUDA 11511 def test_profiler_emit_nvtx(self, device): 11512 # This test is not intended to ensure correctness of nvtx ranges. 11513 # That would require something a great deal more complex (you'd have to create a 11514 # profile in a subprocess, open it, and parse the sql somehow). 11515 # This test is merely intended to catch if emit_nvtx breaks on construction. 11516 a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device) 11517 with torch.cuda.profiler.profile(): 11518 with emit_nvtx(): 11519 a.add(1.0) 11520 11521 @onlyCUDA 11522 def test_rnn_backward_to_input_but_not_parameters(self, device): 11523 # this checks whether it is possible to not require 11524 # weight parameters, but require inputs, see #7722 11525 l = torch.nn.LSTM(2, 3).to(device) 11526 for p in l.parameters(): 11527 p.requires_grad = False 11528 s = torch.randn(1, 1, 2, requires_grad=True, device=device) 11529 out, _ = l(s) 11530 out.sum().backward() 11531 self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0) 11532 11533 @unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required") 11534 def test_profiler_emit_itt(self, device): 11535 # This test is not intended to ensure correctness of itt ranges. 11536 # That would require something a great deal more complex (you'd have to create a 11537 # profile in a subprocess, open it, and parse the sql somehow). 11538 # This test is merely intended to catch if emit_itt breaks on construction. 
11539 a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device) 11540 with emit_itt(): 11541 a.add(1.0) 11542 11543 @skipIfMps # the test doesn't work as randn is not supported with type long 11544 @deviceCountAtLeast(1) 11545 def test_grad_assignment(self, devices): 11546 x = torch.randn(5, 5, device=devices[0]) 11547 11548 # Tests that the wrong type raises 11549 with self.assertRaisesRegex(TypeError, "expected to be a Tensor or None"): 11550 x.grad = 0 11551 11552 # Tests that the wrong shape raises 11553 with self.assertRaises(RuntimeError): 11554 x.grad = torch.randn(2, 2, device=devices[0]) 11555 11556 # Tests that the wrong dtype raises 11557 with self.assertRaises(RuntimeError): 11558 x.grad = torch.randn(5, 5, dtype=torch.long, device=devices[0]) 11559 11560 # Tests that self-assignment raises 11561 with self.assertRaises(RuntimeError): 11562 x.grad = x 11563 11564 # Tests device -> cpu grad assignment raises 11565 if self.device_type != "cpu": 11566 with self.assertRaises(RuntimeError): 11567 t_cpu = torch.rand(5, 5) 11568 t_cpu.grad = torch.randn(5, 5, device=devices[0]) 11569 11570 # Tests half type on CUDA 11571 if self.device_type == "cuda": 11572 x = x.to(dtype=torch.half, device=devices[0]) 11573 x.grad = torch.zeros_like(x) 11574 11575 # Tests cross-device assignment raises 11576 if len(devices) > 1: 11577 x = torch.randn(5, 5, device=devices[0]) 11578 with self.assertRaises(RuntimeError): 11579 x.grad = torch.randn(5, 5, device=devices[1]) 11580 11581 @dtypesIfMPS(torch.float32) 11582 @deviceCountAtLeast(1) 11583 @dtypes(torch.float, torch.double) 11584 def test_requires_grad_factory(self, devices, dtype): 11585 fns = [torch.ones_like, torch.randn_like] 11586 x = torch.randn(2, 3, dtype=dtype, device=devices[0]) 11587 11588 for fn in fns: 11589 for requires_grad in [True, False]: 11590 output = fn( 11591 x, dtype=dtype, device=devices[0], requires_grad=requires_grad 11592 ) 11593 self.assertEqual(requires_grad, output.requires_grad) 11594 self.assertIs(dtype, output.dtype) 11595 self.assertEqual(devices[0], str(x.device)) 11596 11597 @deviceCountAtLeast(2) 11598 def test_unused_output_device(self, devices): 11599 from torch.nn.parallel._functions import Broadcast 11600 11601 x = torch.randn(5, 5, dtype=torch.float, device=devices[0], requires_grad=True) 11602 outputs = Broadcast.apply(list(range(len(devices))), x) 11603 y = outputs[-1] * 2 11604 y.sum().backward() 11605 self.assertEqual(x.grad, torch.ones(5, 5) * 2) 11606 11607 @deviceCountAtLeast(2) 11608 def test_backward_device(self, devices): 11609 # check that current device matches the variable's device 11610 device = [None] 11611 11612 class Identity(torch.autograd.Function): 11613 @staticmethod 11614 def forward(ctx, x): 11615 return x.clone() 11616 11617 @staticmethod 11618 def backward(ctx, grad_output): 11619 device[0] = grad_output.device 11620 return grad_output.clone() 11621 11622 v = torch.randn(1, device=devices[1], requires_grad=True) 11623 Identity.apply(v).backward() 11624 self.assertEqual(str(device[0]), devices[1]) 11625 11626 @deviceCountAtLeast(2) 11627 def test_inputbuffer_add_multidevice(self, devices): 11628 input = torch.randn(1, device=devices[0], requires_grad=True) 11629 output = input.to(device=devices[1]) + input.to(device=devices[1]) 11630 output.backward() 11631 11632 @onlyCPU 11633 def test_copy_(self, device): 11634 # At the time of writing this test, copy_ is not generated from native_functions.yaml 11635 # there was a bug that bfloat16 was not recognized as floating. 
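        # The loop below copies a float tensor that requires grad into empty
        # tensors of every floating dtype (including half and bfloat16) and
        # checks that the destination ends up requiring grad, i.e. copy_ and
        # .to() are treated as differentiable for floating-point destinations.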
11636 x = torch.randn(10, device=device, requires_grad=True) 11637 floating_dt = floating_types_and(torch.half, torch.bfloat16) 11638 for dt in floating_dt: 11639 y = torch.empty(10, device=device, dtype=dt) 11640 y.copy_(x) 11641 self.assertTrue(y.requires_grad) 11642 z = x.to(torch.bfloat16) 11643 self.assertTrue(z.requires_grad) 11644 11645 def test_copy_forward_ad_broadcasting(self, device): 11646 # copy_ allows the src to have a different shape from self as long as src is 11647 # broadcastable to self. Make sure forward AD handles this case. 11648 primal = torch.rand(3, 3, device=device) 11649 tangent = torch.rand(3, 3, device=device) 11650 non_dual = torch.rand(1, 3, 3, device=device) 11651 11652 with fwAD.dual_level(): 11653 dual = fwAD.make_dual(primal, tangent) 11654 non_dual.copy_(dual) 11655 11656 def test_copy_forward_ad_same_layout_copies_grad(self, device): 11657 primal = torch.tensor([[3.0], [4.0]], device=device) 11658 tangent = torch.tensor([[5.0], [6.0]], device=device) 11659 11660 with fwAD.dual_level(): 11661 x_dual = fwAD.make_dual(primal, tangent) 11662 non_dual = torch.tensor([[1.0], [2.0]]) 11663 non_dual.copy_(x_dual) 11664 self.assertTrue(fwAD.unpack_dual(non_dual).tangent is not tangent) 11665 11666 @onlyCUDA 11667 def test_simple_reentrant_cross_device(self, device): 11668 class ReentrantFunc(Function): 11669 _cpu_mode = True 11670 11671 @staticmethod 11672 def forward(ctx, x): 11673 return x * (x + 2) 11674 11675 @staticmethod 11676 def backward(ctx, grad_output): 11677 with torch.enable_grad(): 11678 if ReentrantFunc._cpu_mode: 11679 new_param = torch.randn(2, 2, requires_grad=True) 11680 (new_param**2).sum().backward() 11681 else: 11682 new_param = torch.randn(2, 2, device=device, requires_grad=True) 11683 (new_param**2).sum().backward() 11684 return grad_output 11685 11686 # Reentrant starts on GPU thread, finishes on GPU thread 11687 x = torch.randn(2, 2, device=device, requires_grad=True) 11688 out = ReentrantFunc.apply(x) 11689 out.sum().backward() 11690 11691 # Reentrant starts on CPU thread, finishes on GPU thread 11692 x = torch.randn(2, 2, requires_grad=True) 11693 # set ReentrantFunc node to GPU to emit tasks to GPU queue 11694 ReentrantFunc._cpu_mode = False 11695 out = ReentrantFunc.apply(x) 11696 out.sum().backward() 11697 11698 # Reentrant starts on GPU thread, finishes on CPU thread 11699 x = torch.randn(2, 2, device=device, requires_grad=True) 11700 # set ReentrantFunc node to CPU to emit tasks to CPU queue 11701 ReentrantFunc._cpu_mode = True 11702 out = ReentrantFunc.apply(x) 11703 out.sum().backward() 11704 11705 @onlyCUDA 11706 def test_cross_device_reentrant_autograd(self, device): 11707 # Output on gpu so that this task will be associated with the gpu thread 11708 def fn_on_gpu(inp): 11709 # Artificially increase the priority of the next op to make sure it runs 11710 # as soon as we reach it before the ops of branch1. 11711 dummy = inp * 2 * 2 * 2 * 2 11712 return inp.to(device=device) 11713 11714 def parent_on_cpu(inp): 11715 # Slow branch of ops on gpu so that the work queue for the gpu thread 11716 # won't empty too quickly. They also have smaller priorities than the 11717 # ones created by fn_on_gpu 11718 branch1 = inp.to(device=device) 11719 branch1 = branch1 / branch1 11720 branch1 = branch1 / branch1 11721 branch1 = branch1 / branch1 11722 # Perform checkpoint on cpu tensors. So the last op performed in the reentrant 11723 # autograd is an AccumulateGrad that runs on the cpu thread for the gpu thread.
11724 # So the cpu thread will notify the gpu thread with an empty NodeTask. 11725 branch2 = checkpoint(fn_on_gpu, inp, use_reentrant=True) 11726 out = branch2 + branch1 11727 return out 11728 11729 inp = torch.rand(2, requires_grad=True) 11730 out = parent_on_cpu(inp) 11731 # This will segfault if the empty NodeTask is not handled properly in the 11732 # gpu thread ReadyQueue 11733 out.sum().backward() 11734 11735 def test_inplace_on_view_backprop_base(self, device): 11736 # modify view and back-prop through base 11737 root = torch.randn(2, 2, device=device, requires_grad=True) 11738 x = root.clone() 11739 v1 = x.narrow(0, 0, 1) 11740 v1.mul_(2) 11741 x.sum().backward() 11742 self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]]) 11743 11744 def test_inplace_on_view_backprop_view_of_view(self, device): 11745 # modify view and backprop through view-of-view 11746 root = torch.randn(2, 2, device=device, requires_grad=True) 11747 x = root.clone() 11748 v1 = x.narrow(0, 0, 1) 11749 v2 = x.narrow(0, 0, 1) 11750 v1.mul_(2) 11751 v2.sum().backward() 11752 self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]]) 11753 11754 def test_inplace_on_view_of_view(self, device): 11755 # modify view-of-view and backprop through base 11756 root = torch.randn(2, 2, device=device, requires_grad=True) 11757 x = root.clone() 11758 11759 v1 = x.narrow(0, 0, 1) 11760 v2 = v1.narrow(1, 1, 1) 11761 v2.mul_(2) 11762 x.sum().backward() 11763 self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]]) 11764 11765 @skipIfMps # the test doesn't work on MPS as double types are not supported 11766 def test_inplace_on_view_then_no_grad(self, device): 11767 # Perform an in-place operation on a view of a non-leaf variable. 11768 a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True) 11769 b = a * 2 11770 c = b.view_as(b) 11771 c[0][0] = 3 11772 11773 # Force a graph update with grad disabled. 
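        # Accessing c.grad_fn is what triggers the deferred rebase of the view's
        # graph after the in-place write above; doing so under no_grad should not
        # prevent the backward call below from working.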
11774 with torch.no_grad(): 11775 c.grad_fn 11776 11777 c.sum().backward() 11778 11779 @skipIfMps # the test doesn't work on MPS as double types are not supported 11780 def test_inplace_on_view_gradcheck(self, device): 11781 # gradcheck modifications to views 11782 a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True) 11783 b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True) 11784 11785 def func(root, b): 11786 x = root.clone() 11787 x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b) 11788 x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b) 11789 return x 11790 11791 gradcheck(func, [a, b], raise_exception=True) 11792 go = torch.randn( 11793 a.size(), dtype=torch.double, device=device, requires_grad=True 11794 ) 11795 gradgradcheck(func, (a, b), (go,)) 11796 11797 def test_inplace_on_view_multiple_outputs(self, device): 11798 root = torch.arange(9.0, dtype=torch.double).reshape(3, 3).requires_grad_() 11799 x = root.clone() 11800 v1 = x.unbind() 11801 with self.assertRaises(RuntimeError): 11802 v1[0].mul_(2) 11803 11804 @skipIfMps # the test doesn't work on MPS as double types are not supported 11805 def test_inplace_on_view_of_multiple_output_view(self, device): 11806 a = torch.rand( 11807 10, dtype=torch.double, device=device, requires_grad=True 11808 ).clone() 11809 b = a.unbind(0) 11810 c = b[0].view_as(b[0]) 11811 with self.assertRaises(RuntimeError): 11812 c.mul_(2) 11813 11814 @skipIfMps # MPS backend doesn't support double types 11815 def test_inplace_multiple_output_view_of_view(self, device): 11816 a = torch.rand( 11817 10, dtype=torch.double, device=device, requires_grad=True 11818 ).clone() 11819 b = a.view_as(a) 11820 c = b.unbind(0) 11821 with self.assertRaises(RuntimeError): 11822 c[0].mul_(2) 11823 11824 @skipIfMps # MPS backend doesn't support double types 11825 def test_inplace_on_view_makes_base_require_grad(self, device): 11826 # in-place modification to view makes base require grad 11827 a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False) 11828 b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True) 11829 11830 def func(root, b): 11831 x = root.clone() 11832 self.assertFalse(x.requires_grad) 11833 x.narrow(1, 2, 2).mul_(b) 11834 self.assertTrue(x.requires_grad) 11835 return x 11836 11837 gradcheck(func, [a, b], raise_exception=True) 11838 go = torch.randn( 11839 a.size(), dtype=torch.double, device=device, requires_grad=True 11840 ) 11841 gradgradcheck(func, (a, b), (go,)) 11842 11843 def test_inplace_on_view_backprop_view(self, device): 11844 # modify view and backprop through view 11845 a = torch.tensor([2.0, 5.0], device=device, requires_grad=False) 11846 b = torch.tensor([3.0], device=device, requires_grad=True) 11847 res = a.narrow(0, 1, 1).mul_(b) 11848 res.sum().backward() 11849 self.assertEqual(b.grad.tolist(), [5]) 11850 self.assertIsNone(a.grad) 11851 11852 @skipIfMps # the test doesn't work on MPS as double types are not supported 11853 def test_inplace_on_view_modify_base(self, device): 11854 # Test that an in-place operation on a base that forced it to require 11855 # grad also forces any previous views to require grad and backprop 11856 # correctly 11857 r = torch.ones(1, dtype=torch.double, device=device, requires_grad=True) 11858 11859 def fn(r): 11860 x = torch.ones(5, dtype=torch.double, device=device) 11861 v = x.select(0, 1) 11862 self.assertFalse(v.requires_grad) 11863 self.assertIsNone(v.grad_fn) 11864 x.add_(r) # v is now dependent on r due to the in-place op on x 11865 
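            # The in-place add_ gives x (and therefore its existing view v) an
            # autograd history, so v now requires grad and backprops to r.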
self.assertTrue(v.requires_grad) 11866 return v 11867 11868 gradcheck(fn, [r]) 11869 gradgradcheck(fn, [r]) 11870 11871 @skipIfMps # the test doesn't work on MPS as double types are not supported 11872 def test_inplace_on_view_python(self, device): 11873 # in-place modifications of Python-autograd created view 11874 a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True) 11875 b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True) 11876 11877 class PyAdd(torch.autograd.Function): 11878 @staticmethod 11879 def forward(ctx, x, y): 11880 ctx.mark_dirty(x) 11881 x.add_(y) 11882 return x 11883 11884 @staticmethod 11885 def backward(ctx, grad): 11886 return grad, grad 11887 11888 def func(root, b): 11889 x = root.clone() 11890 PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b) 11891 PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b) 11892 return x 11893 11894 gradcheck(func, [a, b], raise_exception=True) 11895 go = torch.randn( 11896 a.size(), dtype=torch.double, device=device, requires_grad=True 11897 ) 11898 gradgradcheck(func, (a, b), (go,)) 11899 11900 def test_inplace_on_view_non_contig(self, device): 11901 root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True) 11902 x = root.clone() 11903 v1 = x.narrow(0, 0, 1) 11904 v2 = v1.narrow(1, 1, 1) 11905 v2.mul_(2) 11906 x.sum().backward() 11907 self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]]) 11908 11909 def test_inplace_on_view_multi_output_unsafe(self, device): 11910 for f in [ 11911 lambda t: t.unsafe_split(1), 11912 lambda t: t.unsafe_split_with_sizes((1, 1, 1)), 11913 lambda t: t.unsafe_chunk(3), 11914 ]: 11915 a = torch.randn(3, 3, device=device, requires_grad=True) 11916 b = a + a 11917 s1, s2, s3 = f(b) 11918 s1.mul_(s2) 11919 s1.sum().backward() 11920 11921 def test_inplace_on_view_multi_output_safe(self, device): 11922 for f in [ 11923 lambda t: t.split(1), 11924 lambda t: t.split_with_sizes((1, 1, 1)), 11925 lambda t: t.chunk(3), 11926 ]: 11927 a = torch.randn(3, 3, device=device, requires_grad=True) 11928 b = a + a 11929 s1, s2, s3 = f(b) 11930 error_msg = ( 11931 "This view is the output of a function that returns multiple views." 
11932 ) 11933 with self.assertRaisesRegex(RuntimeError, error_msg): 11934 s1.mul_(s2) 11935 11936 def test_inplace_on_view_undefined_grad_output(self, device): 11937 a = torch.tensor([1.0], requires_grad=True) 11938 c = a.clone() 11939 v = c[:] 11940 b = torch.tensor(1.0, requires_grad=True) 11941 11942 class InplaceFunc(torch.autograd.Function): 11943 @staticmethod 11944 def forward(ctx, x, other): 11945 ctx.mark_dirty(x) 11946 return x.mul_(2) 11947 11948 @staticmethod 11949 def backward(ctx, grad): 11950 return grad * 2, None 11951 11952 out = InplaceFunc.apply(v, b) 11953 out.backward() 11954 self.assertIsNone(b.grad) 11955 self.assertEqual(a.grad.item(), 2) 11956 11957 @skipIfMps # the test doesn't work on MPS as double types are not supported 11958 def test_mv_grad_stride_0(self, device): 11959 # Reference: https://github.com/pytorch/pytorch/issues/38315 11960 mat = torch.randn(2, 2, dtype=torch.double, device=device) 11961 vec = torch.randn(1, dtype=torch.double, device=device).requires_grad_(True) 11962 11963 def fn(vec): 11964 # Expand inside the function to make sure the input to 11965 # gradcheck does not have overlapping memory 11966 vec = vec.expand(2) 11967 return (mat @ vec).sum() 11968 11969 gradcheck(fn, (vec)) 11970 gradgradcheck(fn, (vec)) 11971 11972 @onlyCUDA 11973 def test_gradcheck_input_output_different_device(self, device): 11974 x = torch.ones((1,), dtype=torch.double, device="cuda", requires_grad=True) 11975 gradcheck(lambda x: x.to("cpu"), (x,)) 11976 11977 x = torch.ones((1,), dtype=torch.double, device="cpu", requires_grad=True) 11978 gradcheck(lambda x: x.to("cuda"), (x,)) 11979 11980 def test_strided_leaf_grad_layout(self, device): 11981 # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf. 11982 for fmt_a in (torch.contiguous_format, torch.channels_last): 11983 for fmt_b in (torch.contiguous_format, torch.channels_last): 11984 a = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_a) 11985 b = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_b) 11986 a.requires_grad_() 11987 b.requires_grad_() 11988 # checks (1) for broadcasted gradients 11989 a.sum().backward() 11990 self.assertEqual(a.grad.stride(), a.stride()) 11991 b.sum().backward() 11992 self.assertEqual(b.grad.stride(), b.stride()) 11993 # checks (1) for non-broadcasted gradients 11994 a.grad = None 11995 b.grad = None 11996 (a * b).sum().backward() 11997 self.assertEqual(a.grad.stride(), a.stride()) 11998 self.assertEqual(b.grad.stride(), b.stride()) 11999 12000 # (2) If leaf isn't dense, checks that grads are rowmajor contiguous. 
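        # c below is allocated with strides (4, 2), so it is not
        # non-overlapping-and-dense; its grad is expected to fall back to the
        # default row-major layout, i.e. strides (2, 1).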
12001 c = torch.empty_strided((2, 2), (4, 2), device=device).copy_( 12002 torch.rand((2, 2), device=device) 12003 ) 12004 c.requires_grad_() 12005 d = torch.rand((2, 2), device=device) 12006 # checks (2) for broadcasted gradients 12007 c.sum().backward() 12008 self.assertEqual(c.grad.stride(), (2, 1)) 12009 # checks (2) for non-broadcasted gradients 12010 c.grad = None 12011 (c * d).sum().backward() 12012 self.assertEqual(c.grad.stride(), (2, 1)) 12013 12014 @skipIfMps 12015 def test_copy_r_to_c(self, device): 12016 out_c = torch.empty(3, 2, dtype=torch.cdouble, device=device) 12017 inp_r = torch.randn(3, 2, dtype=torch.double, device=device, requires_grad=True) 12018 12019 def do_test(): 12020 out_c.copy_(inp_r) 12021 out_c_inter = out_c.sum() 12022 out_c_inter.abs().backward() 12023 with torch.no_grad(): 12024 self.assertEqual( 12025 inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real 12026 ) 12027 12028 self.assertNotWarn(do_test) 12029 12030 def test_to_r_to_c(self, device): 12031 def do_test(): 12032 inp_r = torch.randn( 12033 3, 2, dtype=torch.double, device=device, requires_grad=True 12034 ) 12035 out = inp_r.to(torch.complex128) 12036 out_inter = out.sum() 12037 out_inter.abs().backward() 12038 with torch.no_grad(): 12039 self.assertEqual( 12040 inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real 12041 ) 12042 12043 self.assertNotWarn(do_test) 12044 12045 def test_non_differentiable_ops(self, device): 12046 # Just make sure the op doesn't raise an error 12047 # and resulting tensor has requires_grad=False. 12048 x = torch.tensor([[1, 2], [3, 4.0]], requires_grad=True, device=device) 12049 out = torch.isin(x, torch.tensor([2, 3], device=device)) 12050 self.assertFalse(out.requires_grad) 12051 12052 x = torch.randn(3, 3, requires_grad=True) 12053 out = torch.signbit(x) 12054 self.assertFalse(out.requires_grad) 12055 12056 def test_warning_in_backward(self, device): 12057 # Test warning during backward are always propagated as python warnings (gh-50209) 12058 # NOTE: For device=cuda, warning gets propagated from a worker thread 12059 a = torch.zeros((), device=device, requires_grad=True) 12060 b = torch._C._nn._test_warn_in_autograd(a) 12061 12062 with self.assertWarnsRegex(UserWarning, "Warn from backward"): 12063 b.backward() 12064 12065 def test_complex_scalar_backward(self, device): 12066 a = torch.zeros(1, device=device, requires_grad=True) 12067 b = a * 0.5j 12068 12069 msg = "grad can be implicitly created only for real scalar outputs" 12070 with self.assertRaisesRegex(RuntimeError, msg): 12071 b.backward() 12072 12073 with self.assertRaisesRegex(RuntimeError, msg): 12074 torch.autograd.grad(b, a) 12075 12076 def test_pow_real_negative_base_complex_exponent(self, device): 12077 # OpInfo doesn't naturally support input of mixed types, hence this test here. 
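        # A negative real base with a complex exponent produces complex outputs,
        # so the gradient w.r.t. the exponent is complex; gradcheck is run both
        # with a tensor base and with the python scalar -1 as the base.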
12078 base = -torch.ones(2, device=device, dtype=torch.double) 12079 exponent = torch.randn( 12080 2, device=device, dtype=torch.cdouble, requires_grad=True 12081 ) 12082 12083 def fn(exponent): 12084 return torch.pow(base, exponent) 12085 12086 torch.autograd.gradcheck(fn, (exponent,)) 12087 12088 def fn(exponent): 12089 return torch.pow(-1, exponent) 12090 12091 torch.autograd.gradcheck(fn, (exponent,)) 12092 12093 def test_resize_version_bump(self, device): 12094 x = torch.rand((1,), device=device) 12095 y = torch.randn((3,), device=device) 12096 x.resize_((1, 2)) 12097 self.assertEqual(x._version, 1) 12098 x.resize_as_(y) 12099 self.assertEqual(x._version, 2) 12100 12101 # In the following cases, `resize` is no-op, 12102 # so no version bumps. 12103 x.resize_((3,)) 12104 self.assertEqual(x._version, 2) 12105 12106 x.resize_as_(y) 12107 self.assertEqual(x._version, 2) 12108 12109 12110class TestAllowMutationOnSaved(TestCase): 12111 def assertClonedLenEqual(self, ctx, n): 12112 self.assertEqual(len(list(ctx.cloned.items())), n) 12113 12114 def assertTIDMapLenEqual(self, ctx, n): 12115 self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n) 12116 12117 def test_basic(self): 12118 a = torch.rand(2, 3, requires_grad=True) 12119 12120 def fn(a): 12121 b = a.clone() 12122 out = (b**2).sum() 12123 b.sin_() 12124 out.sum().backward() 12125 return a.grad 12126 12127 msg = ( 12128 "variables needed for gradient computation has been modified by an inplace" 12129 ) 12130 with self.assertRaisesRegex(RuntimeError, msg): 12131 fn(a) 12132 12133 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12134 da = fn(a) 12135 12136 self.assertTrue(torch.allclose(a * 2, da)) 12137 self.assertClonedLenEqual(ctx, 0) 12138 12139 def test_views(self): 12140 a = torch.rand(2, 3, requires_grad=True) 12141 12142 def fn(a): 12143 b = a.clone() 12144 c = b.view_as(b) 12145 out = (b**2).sum() # How does this work? 
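            # Under allow_mutation_on_saved_tensors, b is saved for the backward
            # of b**2; when the in-place sin_ below mutates it through the view c,
            # the context clones the saved value first, so backward still sees the
            # original b and a.grad comes out as 2 * a.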
12146 c.sin_() 12147 out.sum().backward() 12148 return a.grad 12149 12150 msg = ( 12151 "variables needed for gradient computation has been modified by an inplace" 12152 ) 12153 with self.assertRaisesRegex(RuntimeError, msg): 12154 fn(a) 12155 12156 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12157 da = fn(a) 12158 12159 self.assertClonedLenEqual(ctx, 0) 12160 self.assertTrue(torch.allclose(a * 2, da)) 12161 12162 def test_save_base_and_modify_view(self): 12163 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12164 a = torch.rand(2, 3, requires_grad=True) 12165 b = a.clone() 12166 c = b[:1] 12167 out = b**2 12168 # modify the view 12169 c *= 10 12170 # self.assertClonedLenEqual(ctx, 1) 12171 out.sum().backward() 12172 self.assertClonedLenEqual(ctx, 0) 12173 12174 self.assertClonedLenEqual(ctx, 0) 12175 self.assertTrue(torch.allclose(a * 2, a.grad)) 12176 12177 def test_save_view_modify_base(self): 12178 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12179 a = torch.rand(2, 3, requires_grad=True) 12180 b = a.clone() 12181 c = b[:] 12182 out = (c**2).sum() 12183 b *= 2 12184 out.backward() 12185 self.assertTrue(torch.allclose(a * 2, a.grad)) 12186 12187 def test_double_backward(self): 12188 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12189 a = torch.rand(2, 3, requires_grad=True) 12190 b = a.clone() 12191 out = (b**2).sum() 12192 b.sin_() 12193 torch.autograd.grad(out, a, create_graph=True) 12194 (da,) = torch.autograd.grad(out, a, create_graph=True) 12195 (d2a,) = torch.autograd.grad(da.sum(), a) 12196 12197 self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a)) 12198 self.assertClonedLenEqual(ctx, 0) 12199 12200 def test_saved_but_not_anymore(self): 12201 # Make sure we don't clone if the tensor was once saved, but 12202 # by the time we do in-place, it is no longer saved 12203 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12204 a = torch.randn(2, 3, requires_grad=True).clone() 12205 out = (a**2).sum() 12206 self.assertTIDMapLenEqual(ctx, 1) 12207 self.assertClonedLenEqual(ctx, 0) 12208 out.backward() 12209 a.sin_() 12210 self.assertClonedLenEqual(ctx, 0) 12211 out = (a**2).sum() 12212 a.sin_() 12213 self.assertClonedLenEqual(ctx, 1) 12214 del out 12215 self.assertClonedLenEqual(ctx, 0) 12216 12217 def test_saved_same_tensor_many_times(self): 12218 # We should only clone once 12219 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12220 a = torch.randn(2, 3, requires_grad=True).clone() 12221 b = a**2 12222 c = a**2 12223 a.sin_() 12224 self.assertClonedLenEqual(ctx, 1) 12225 del b, c 12226 self.assertClonedLenEqual(ctx, 0) 12227 12228 def test_saved_same_tensor_different_versions(self): 12229 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12230 a = torch.randn(2, 3, requires_grad=True).clone() 12231 b = a**2 12232 a.sin_() 12233 c = a**2 12234 a.sin_() 12235 self.assertClonedLenEqual(ctx, 2) 12236 del b 12237 self.assertClonedLenEqual(ctx, 1) 12238 del c 12239 self.assertClonedLenEqual(ctx, 0) 12240 12241 def test_with_math_views(self): 12242 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12243 a = torch.tensor([1 + 1j], requires_grad=True).clone() 12244 b = a.conj() 12245 out = (b**2).sum() 12246 a.sin_() 12247 out.abs().backward() 12248 12249 a = torch.tensor([1 + 1j], requires_grad=True).clone() 12250 b = a.conj() 12251 out = (b**2).sum() 12252 # in this case, it is no longer a view it seems 12253 b.sin_() 12254 
out.abs().backward() 12255 12256 def test_with_out_variant(self): 12257 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12258 a = torch.tensor([1.0], requires_grad=True) 12259 b = torch.tensor([1.0]) 12260 c = torch.tensor([2.0]) 12261 out = a * b 12262 self.assertTIDMapLenEqual(ctx, 1) 12263 torch.sin(c, out=b) 12264 self.assertClonedLenEqual(ctx, 1) 12265 out.backward() 12266 self.assertClonedLenEqual(ctx, 0) 12267 12268 def test_backward_out_of_context(self): 12269 # Out of context 12270 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12271 a = torch.rand(2, 3, requires_grad=True) 12272 out = (a**2).sum() 12273 12274 msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" 12275 with self.assertRaisesRegex(AssertionError, msg): 12276 out.backward() 12277 12278 # Different context 12279 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12280 a = torch.rand(2, 3, requires_grad=True) 12281 out = (a**2).sum() 12282 12283 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12284 with self.assertRaisesRegex(AssertionError, msg): 12285 out.backward() 12286 12287 def test_disallow_nesting(self): 12288 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12289 msg = "allow_mutation_on_saved_tensors contexts cannot be nested" 12290 with self.assertRaisesRegex(RuntimeError, msg): 12291 with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12292 pass 12293 12294 12295class TestAutogradInferenceMode(TestCase): 12296 def _is_inference_tensor(self, tensor): 12297 try: 12298 err_msg = "Inference tensors do not track version counter" 12299 with self.assertRaisesRegex(RuntimeError, err_msg): 12300 tensor._version 12301 return True 12302 except AssertionError as e: 12303 return False 12304 12305 def test_inference_mode_context_manager(self): 12306 self.assertFalse(torch.is_inference_mode_enabled()) 12307 with torch.inference_mode(): 12308 self.assertTrue(torch.is_inference_mode_enabled()) 12309 with torch.inference_mode(False): 12310 self.assertFalse(torch.is_inference_mode_enabled()) 12311 self.assertTrue(torch.is_inference_mode_enabled()) 12312 self.assertFalse(torch.is_inference_mode_enabled()) 12313 12314 def test_inference_mode_decorator(self): 12315 def func(x): 12316 self.assertEqual(torch.is_inference_mode_enabled(), mode) 12317 return x * x 12318 12319 for mode, use_kwarg in product((True, False, None), (True, False)): 12320 if mode is None: 12321 if use_kwarg: 12322 decorated = torch.inference_mode(mode=func) 12323 else: 12324 decorated = torch.inference_mode(func) 12325 mode = True 12326 else: 12327 if use_kwarg: 12328 decorated = torch.inference_mode(mode=mode)(func) 12329 else: 12330 decorated = torch.inference_mode(mode)(func) 12331 12332 for requires_grad in (True, False): 12333 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12334 d = decorated(c) 12335 self.assertTrue(not mode or torch.is_inference(d)) 12336 self.assertEqual(d.requires_grad, requires_grad and not mode) 12337 12338 def test_inference_mode_tensor_creation(self): 12339 with torch.inference_mode(): 12340 # new tensors created through constructors are inference tensors 12341 c = torch.ones(1, 2, 3) 12342 self.assertFalse(c.requires_grad) 12343 self.assertTrue(torch.is_inference(c)) 12344 12345 # requires_grad doesn't change inference tensor behavior in InferenceMode 12346 tmp = torch.ones(1, 2, 3, requires_grad=True) 12347 self.assertTrue(tmp.requires_grad) 12348 
self.assertTrue(torch.is_inference(tmp)) 12349 12350 tmp = torch.ones(1, 2, 3).requires_grad_(False) 12351 self.assertFalse(tmp.requires_grad) 12352 self.assertTrue(torch.is_inference(tmp)) 12353 12354 def test_inference_mode_existing_autograd_session(self): 12355 s = torch.ones(1, 2, 3, requires_grad=True) 12356 a = s.clone() 12357 12358 # `a` gets saved outside of inference mode 12359 out = a * a 12360 with torch.inference_mode(): 12361 a.add_(2) 12362 12363 self.assertFalse(torch.is_inference(a)) 12364 # tensors created outside of inference mode aren't 12365 # inference tensors, so they will still have their 12366 # version counters tracked 12367 err_msg = ( 12368 "one of the variables needed for gradient computation has been " 12369 "modified by an inplace operation" 12370 ) 12371 with self.assertRaisesRegex(RuntimeError, err_msg): 12372 out.backward(torch.ones_like(out)) 12373 12374 def test_inference_mode_inf_tensor_in_inf_mode_functional_op(self): 12375 def functional_op(x): 12376 return x * x 12377 12378 with torch.inference_mode(): 12379 for requires_grad in (True, False): 12380 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12381 12382 # performing a non-view operation produces a inference tensor 12383 # that does not require grad 12384 func_out = functional_op(c) 12385 self.assertTrue(torch.is_inference(func_out)) 12386 self.assertFalse(func_out.requires_grad) 12387 12388 def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self): 12389 @torch.inference_mode() 12390 def run_test(fn): 12391 for requires_grad in (True, False): 12392 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12393 12394 # after performing inplace operation, tensor is still 12395 # an inference tensor 12396 fn(c) 12397 self.assertTrue(torch.is_inference(c)) 12398 self.assertEqual(c.requires_grad, requires_grad) 12399 12400 run_test(lambda x: x.add_(2)) 12401 run_test(lambda x: x.transpose_(0, 1)) 12402 12403 # inplace ops with manual kernel for ADInplaceOrView key in VariableTypeManual.cpp 12404 run_test(lambda x: x.resize_(1, 2)) 12405 run_test(lambda x: x.resize_as_(torch.ones(1, 2))) 12406 run_test(lambda x: x.copy_(torch.ones(1, 2, 3))) 12407 12408 def test_inference_mode_inf_tensor_in_inf_mode_view_op(self): 12409 with torch.inference_mode(): 12410 for requires_grad in (True, False): 12411 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12412 12413 # perform view operation produces inference tensor 12414 # that does not require grad 12415 view_out = c.view(-1) 12416 self.assertTrue(torch.is_inference(view_out)) 12417 self.assertFalse(view_out.requires_grad) 12418 12419 def test_inference_mode_inf_tensor_in_normal_mode_functional_op(self): 12420 def functional_op(x): 12421 return x * x 12422 12423 for requires_grad in (True, False): 12424 with torch.inference_mode(): 12425 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12426 12427 func_out = functional_op(c) 12428 self.assertFalse(torch.is_inference(func_out)) 12429 self.assertFalse(func_out.requires_grad) 12430 self.assertTrue(func_out.is_leaf) 12431 12432 def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self): 12433 def run_test(fn): 12434 for requires_grad in (False, True): 12435 with torch.inference_mode(): 12436 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12437 12438 if requires_grad: 12439 # leaf variable that requires grad is being used in an inplace 12440 # operation when requires_grad=True 12441 pass 12442 else: 12443 err_msg = "Inplace update to inference tensor outside InferenceMode" 12444 with 
self.assertRaisesRegex(RuntimeError, err_msg): 12445 fn(c) 12446 12447 run_test(lambda x: x.add_(2)) 12448 run_test(lambda x: x.transpose_(0, 1)) 12449 12450 def test_inference_mode_inf_tensor_in_normal_mode_view_op(self): 12451 for requires_grad in (True, False): 12452 with torch.inference_mode(): 12453 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12454 12455 out = c.view(-1) 12456 self.assertTrue(torch.is_inference(out)) 12457 self.assertFalse(out.requires_grad) 12458 self.assertFalse(out._is_view()) 12459 self.assertTrue(out.is_leaf) 12460 12461 def test_normal_tensor_inplace_output_in_inference_mode(self): 12462 def run_test(fn): 12463 for requires_grad in (True, False): 12464 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12465 a = s.clone() 12466 12467 with torch.inference_mode(): 12468 fn(a) 12469 self.assertFalse(torch.is_inference(a)) 12470 self.assertEqual(a.requires_grad, requires_grad) 12471 12472 # inplace -> inplace 12473 fn(a) 12474 self.assertFalse(torch.is_inference(a)) 12475 self.assertEqual(a.requires_grad, requires_grad) 12476 12477 # inplace -> inplace -> view 12478 view_out = a.view(-1) 12479 self.assertFalse(torch.is_inference(view_out)) 12480 self.assertEqual(view_out.requires_grad, requires_grad) 12481 12482 run_test(lambda x: x.add_(2)) 12483 run_test(lambda x: x.transpose_(0, 1)) 12484 12485 def test_normal_tensor_inplace_output_in_normal_mode(self): 12486 def run_test(fn): 12487 for requires_grad in (True, False): 12488 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12489 a = s.clone() 12490 12491 with torch.inference_mode(): 12492 fn(a) 12493 self.assertFalse(torch.is_inference(a)) 12494 self.assertEqual(a.requires_grad, requires_grad) 12495 12496 fn(a) 12497 self.assertFalse(torch.is_inference(a)) 12498 self.assertEqual(a.requires_grad, requires_grad) 12499 12500 # inplace -> inplace 12501 fn(a) 12502 self.assertFalse(torch.is_inference(a)) 12503 self.assertEqual(a.requires_grad, requires_grad) 12504 12505 # inplace -> inplace -> view 12506 view_out = a.view(-1) 12507 self.assertFalse(torch.is_inference(view_out)) 12508 self.assertEqual(view_out.requires_grad, requires_grad) 12509 run_test(lambda x: x.add_(2)) 12510 run_test(lambda x: x.transpose_(0, 1)) 12511 12512 def test_normal_tensor_view_output_in_inference_mode(self): 12513 for requires_grad in (True, False): 12514 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12515 a = s.clone() 12516 12517 with torch.inference_mode(): 12518 out = a.view(-1) 12519 self.assertFalse(torch.is_inference(out)) 12520 self.assertEqual(out.requires_grad, requires_grad) 12521 self.assertTrue(out._is_view()) 12522 12523 # view -> view 12524 tmp = out.view(-1) 12525 self.assertFalse(torch.is_inference(tmp)) 12526 self.assertEqual(tmp.requires_grad, requires_grad) 12527 self.assertTrue(tmp._is_view()) 12528 self.assertTrue(tmp.is_leaf) 12529 12530 # view -> view -> inplace 12531 self.assertTrue(torch.is_inference_mode_enabled()) 12532 tmp.add_(2) 12533 self.assertFalse(torch.is_inference(tmp)) 12534 self.assertEqual(tmp.requires_grad, requires_grad) 12535 # Accessing is_leaf in python tries to update grad_fn and raises: 12536 # A view was created in inference mode and its base or 12537 # another view of its base has been modified inplace in normal mode 12538 # tmp.is_leaf 12539 self.assertEqual(a._version, tmp._version) 12540 12541 def test_normal_tensor_view_output_in_normal_mode(self): 12542 def functional_op(x): 12543 return x * x 12544 12545 for requires_grad in (True, False): 12546 s = torch.ones(1, 
2, 3, requires_grad=requires_grad) 12547 a = s.clone() 12548 12549 with torch.inference_mode(): 12550 out = a.view(-1) 12551 self.assertFalse(torch.is_inference(out)) 12552 self.assertEqual(out.requires_grad, requires_grad) 12553 self.assertTrue(out._is_view()) 12554 self.assertTrue(out.is_leaf) 12555 12556 tmp = functional_op(out) 12557 self.assertFalse(torch.is_inference(tmp)) 12558 self.assertEqual(tmp.requires_grad, requires_grad) 12559 12560 if requires_grad: 12561 err_msg = ( 12562 "A view was created in inference mode and is being modified inplace" 12563 ) 12564 with self.assertRaisesRegex(RuntimeError, err_msg): 12565 out.add_(2) 12566 else: 12567 out.add_(2) 12568 12569 tmp = out.view(2, 3) 12570 self.assertFalse(torch.is_inference(tmp)) 12571 self.assertEqual(tmp.requires_grad, requires_grad) 12572 12573 def test_mix_inference_and_normal_tensor_functional_op(self): 12574 for requires_grad in (True, False): 12575 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12576 12577 with torch.inference_mode(): 12578 c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12579 12580 # add is safe since it doesn't save any variable for backward 12581 out = c.add(s) 12582 self.assertFalse(torch.is_inference(out)) 12583 self.assertEqual(out.requires_grad, requires_grad) 12584 if requires_grad: 12585 # leaf inference tensor with requires_grad=True can still have gradient 12586 out.backward(torch.ones_like(out)) 12587 self.assertEqual(c.grad, torch.ones_like(c)) 12588 12589 if requires_grad: 12590 err_msg = "Inference tensors cannot be saved for backward" 12591 with self.assertRaisesRegex(RuntimeError, err_msg): 12592 c * s 12593 12594 # TODO: Test this with an autograd.Function when it works 12595 # stack stopped capturing a TensorList input 12596 # # inference tensor in TensorList input 12597 # inputs = [s, c] 12598 # with self.assertRaisesRegex(RuntimeError, err_msg): 12599 # torch.stack(inputs) 12600 12601 def test_mix_inference_and_normal_tensor_inplace_op(self): 12602 for requires_grad in (True, False): 12603 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12604 a = s.clone() 12605 12606 with torch.inference_mode(): 12607 c = torch.ones(1, 2, 3) 12608 12609 self.assertTrue(torch.is_inference(c)) 12610 if requires_grad: 12611 err_msg = "Inference tensors cannot be saved for backward" 12612 with self.assertRaisesRegex(RuntimeError, err_msg): 12613 a.mul_(c) 12614 12615 # inference tensor in TensorList input 12616 err_msg = ( 12617 "out=... arguments don't support automatic differentiation, " 12618 "but one of the arguments requires grad" 12619 ) 12620 with self.assertRaisesRegex(RuntimeError, err_msg): 12621 torch.mul(s, s, out=c) 12622 else: 12623 a.mul_(c) 12624 err_msg = "Inplace update to inference tensor outside InferenceMode is not allowed" 12625 with self.assertRaisesRegex(RuntimeError, err_msg): 12626 torch.mul(s, s, out=c) 12627 12628 def test_mix_inference_and_normal_tensor_view_op(self): 12629 for requires_grad in (True, False): 12630 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12631 12632 with torch.inference_mode(): 12633 c = torch.ones(1, 2, 3) 12634 12635 # view_as is a composite op which calls view with only one 12636 # tensor argument. 
So there isn't a mixed inference and normal 12637 # tensor inputs for view ops 12638 tmp1 = c.view_as(s) 12639 self.assertTrue(torch.is_inference(tmp1)) 12640 self.assertFalse(tmp1.requires_grad) 12641 12642 # this is fine since its equivalent as s.view(c.sizes()) which 12643 # isn't a mixed input scenario 12644 tmp2 = s.view_as(c) 12645 self.assertFalse(torch.is_inference(tmp2)) 12646 self.assertEqual(tmp2.requires_grad, requires_grad) 12647 12648 def test_inference_mode_handle_direct_view_on_rebase(self): 12649 def run_test(fn): 12650 for requires_grad in (True, False): 12651 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12652 a = s.clone() 12653 12654 with torch.inference_mode(): 12655 view_out = a.view_as(a) 12656 12657 if requires_grad: 12658 err_msg = "A view was created in inference mode and is being modified inplace" 12659 with self.assertRaisesRegex(RuntimeError, err_msg): 12660 fn(view_out) 12661 else: 12662 fn(view_out) 12663 12664 run_test(lambda x: x.add_(2)) 12665 run_test(lambda x: x.transpose_(0, 1)) 12666 12667 def test_inference_mode_handle_indirect_view_on_rebase(self): 12668 def run_test(fn): 12669 for requires_grad in (True, False): 12670 s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12671 a = s.clone() 12672 12673 with torch.inference_mode(): 12674 view_out = a.view(-1) 12675 12676 fn(a) 12677 if requires_grad: 12678 err_msg = "A view was created in inference mode and its base or another view " 12679 with self.assertRaisesRegex(RuntimeError, err_msg): 12680 view_out.grad_fn 12681 else: 12682 view_out.grad_fn 12683 12684 run_test(lambda x: x.add_(2)) 12685 run_test(lambda x: x.transpose_(0, 1)) 12686 12687 12688class TestMultithreadAutograd(TestCase): 12689 def _run_py_multithread_fn( 12690 self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False 12691 ): 12692 class PropagatingThread(threading.Thread): 12693 """Helper class to propagate exception from child 12694 thread to main thread on join. 12695 12696 Reference: https://stackoverflow.com/a/31614591/5602957 12697 """ 12698 12699 def run(self): 12700 self.exception = None 12701 try: 12702 self.ret = super().run() 12703 except Exception as e: 12704 self.exception = e 12705 12706 def join(self, timeout=None): 12707 super().join(timeout) 12708 if self.exception: 12709 raise self.exception from self.exception 12710 return self.ret 12711 12712 threads = [] 12713 for idx in range(num_threads): 12714 p = PropagatingThread(target=fn, args=((idx, *args) if pass_idx else args)) 12715 p.start() 12716 threads.append(p) 12717 12718 for p in threads: 12719 p.join() 12720 12721 def test_multithreaded_exception_propagation(self): 12722 # Test whether exception in child thread 12723 # are propagated to main thread. 12724 def fn(): 12725 self.assertTrue(False) 12726 12727 with self.assertRaises(AssertionError): 12728 self._run_py_multithread_fn(fn) 12729 12730 def test_simple_backward(self): 12731 # simple multithreaded backward that create threads in the beginning of training 12732 # and everything else is training separately, i.e. inputs, operations, etc. 12733 def train_fn(): 12734 x = torch.ones(5, 5, requires_grad=True) 12735 y = (x + 3) * (x + 4) * 0.5 12736 y.sum().backward() 12737 self.assertEqual(x.grad, x + 3.5) 12738 12739 self._run_py_multithread_fn(train_fn) 12740 12741 def test_simple_backward_same_input(self): 12742 # simple multithreaded backward with only shared inputs (i.e. 
This is common 12743 # for things like Hogwild multithreaded training with multiple CPU threads) 12744 def train_fn_backward(x): 12745 y = (x + 3) * (x + 4) * 0.5 12746 y.sum().backward() 12747 12748 x = torch.ones(5, 5, requires_grad=True) 12749 self._run_py_multithread_fn(train_fn_backward, (x,)) 12750 # Since we are calling backward from multiple threads 12751 # and all threads share the same input, when we do backward 12752 # concurrently, different backwards will all accumulate to 12753 # the same .grad for each input, and the gradients should 12754 # be equal to num_threads * gradient 12755 self.assertEqual(x.grad, 10 * (x + 3.5)) 12756 12757 def train_fn_grad(x): 12758 y = (x + 3) * (x + 4) * 0.5 12759 grads = torch.autograd.grad(y.sum(), x) 12760 self.assertEqual(len(grads), 1) 12761 self.assertEqual(grads[0], x + 3.5) 12762 12763 # since we use functional grad() api, gradients will not 12764 # be accumulate to the same place and should be the same 12765 self._run_py_multithread_fn(train_fn_grad, (x,)) 12766 12767 def test_multi_grad_all_hooks(self): 12768 # Multihooks should behave independently per execution of backward 12769 # Test that the hook fired the number of times we ran backward 12770 # even if those executions occur concurrently on different threads 12771 t1 = torch.rand(2, requires_grad=True) 12772 t2 = torch.rand(2, requires_grad=True) 12773 t3 = torch.rand(2, requires_grad=True) 12774 t4 = torch.rand(2, requires_grad=True) 12775 12776 res = None 12777 count = [0] 12778 hook_lock = threading.Lock() 12779 12780 def hook(grads): 12781 nonlocal res 12782 with hook_lock: 12783 count[0] += 1 12784 grad_is_none = [g is not None for g in grads] 12785 if res is None: 12786 res = grad_is_none 12787 else: 12788 self.assertEqual(res, grad_is_none) 12789 12790 torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook) 12791 12792 out = (t2 * t3).sum() 12793 12794 def backward_retain_graph(out, t2, t3): 12795 out.backward(inputs=(t2, t3), retain_graph=True) 12796 12797 self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5) 12798 12799 self.assertEqual(count[0], 5) 12800 self.assertEqual(res, [False, True, True, False]) 12801 12802 # Leave one hook partially applied 12803 res = None 12804 count = [0] 12805 err_count = [0] 12806 bw_count = [0] 12807 bw_count_lock = threading.Lock() 12808 err_count_lock = threading.Lock() 12809 12810 class Func(torch.autograd.Function): 12811 @staticmethod 12812 def forward(ctx, x): 12813 return x 12814 12815 @staticmethod 12816 def backward(ctx, gO): 12817 with bw_count_lock: 12818 bw_count[0] += 1 12819 if bw_count[0] == 1: 12820 raise RuntimeError("error message") 12821 else: 12822 return gO 12823 12824 out = (Func.apply(t2) * t3).sum() 12825 12826 def backward_retain_graph(out, t2, t3): 12827 try: 12828 out.backward(inputs=(t2, t3), retain_graph=True) 12829 except RuntimeError: 12830 with err_count_lock: 12831 err_count[0] += 1 12832 12833 self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5) 12834 12835 self.assertEqual(count[0], 4) 12836 self.assertEqual(err_count[0], 1) 12837 self.assertEqual(res, [False, True, True, False]) 12838 12839 def test_multi_grad_any_hooks(self): 12840 # Multihooks should behave independently per execution of backward 12841 # Test that the hook fired the number of times we ran backward 12842 # even if those executions occur concurrently on different threads 12843 t1 = torch.rand(2, requires_grad=True) 12844 t2 = torch.rand(2, requires_grad=True) 12845 t3 
= torch.rand(2, requires_grad=True) 12846 t4 = torch.rand(2, requires_grad=True) 12847 12848 res = None 12849 count = [0] 12850 hook_lock = threading.Lock() 12851 12852 def hook(grad): 12853 nonlocal res 12854 with hook_lock: 12855 count[0] += 1 12856 if res is None: 12857 res = "foo" 12858 else: 12859 self.assertEqual(res, "foo") 12860 12861 torch.autograd.graph.register_multi_grad_hook( 12862 (t1, t2, t3, t4), hook, mode="any" 12863 ) 12864 12865 out = (t2 * t3).sum() 12866 12867 def backward_retain_graph(out, t2, t3): 12868 out.backward(inputs=(t2, t3), retain_graph=True) 12869 12870 self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5) 12871 self.assertEqual(count[0], 5) 12872 self.assertEqual(res, "foo") 12873 12874 # Raise an error in one thread's backward 12875 res = None 12876 count = [0] 12877 err_count = [0] 12878 bw_count = [0] 12879 bw_count_lock = threading.Lock() 12880 err_count_lock = threading.Lock() 12881 12882 class Func(torch.autograd.Function): 12883 @staticmethod 12884 def forward(ctx, x): 12885 return x 12886 12887 @staticmethod 12888 def backward(ctx, gO): 12889 with bw_count_lock: 12890 bw_count[0] += 1 12891 if bw_count[0] == 1: 12892 raise RuntimeError("error message") 12893 else: 12894 return gO 12895 12896 out = (Func.apply(t2) * t3).sum() 12897 12898 def backward_retain_graph(out, t2, t3): 12899 try: 12900 out.backward(inputs=(t2, t3), retain_graph=True) 12901 except RuntimeError: 12902 with err_count_lock: 12903 err_count[0] += 1 12904 12905 self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5) 12906 12907 # Expect all 5 threads to increment count since the hook runs before 12908 # the custom backward 12909 self.assertEqual(count[0], 5) 12910 self.assertEqual(err_count[0], 1) 12911 self.assertEqual(res, "foo") 12912 12913 def test_dataparallel_saved_tensors_hooks(self): 12914 def pack(x): 12915 warnings.warn("pack") 12916 return x 12917 12918 _self = self 12919 12920 class Model(torch.nn.Module): 12921 def forward(self, x): 12922 with warnings.catch_warnings(record=True) as w: 12923 y = x * x 12924 if torch.cuda.device_count() >= 2: 12925 # DataParallel is calling the forward in different threads 12926 # without progating TLS, so hooks should not be called here 12927 _self.assertEqual(len(w), 0) 12928 else: 12929 # DataParallel only uses one thread 12930 # so hooks should be called here 12931 _self.assertGreater(len(w), 0) 12932 12933 x = torch.ones(5, 5, requires_grad=True) 12934 model = torch.nn.DataParallel(Model()) 12935 12936 with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x): 12937 model(x) 12938 with warnings.catch_warnings(record=True) as w: 12939 y = x * x 12940 # hooks should be called here 12941 _self.assertGreater(len(w), 0) 12942 12943 def test_python_thread_in_middle(self): 12944 # User might write a network that starts on one CPU thread, then runs its second half 12945 # concurrently with other threads (either via python threading or fork/join calls), 12946 # then calls backward()/grad() on BOTH threads, like a Y pattern from input at the 12947 # bottom to output at the top. This way part of the GraphTask is being shared across 12948 # different threads and we need to ensure user specify retain_graph=True, otherwise 12949 # error out with the correct error message 12950 12951 # Case 1: multiple backward with python threads, retain_graph=False 12952 # should throw error in some threads with no retain_graph. 
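        # For reference on the expected gradient asserted further below (the retain_graph=True
        # case): y_retain = x + x**2 and each thread computes z = y_retain + y_retain**2, so the
        # per-thread gradient is dz/dx = (1 + 2 * y_retain) * (1 + 2 * x)
        # = 4 * x**3 + 6 * x**2 + 4 * x + 1, and the 5 threads accumulate into x_retain.grad,
        # which is why the expected value is 5 * (4 * x_retain**3 + 6 * x_retain**2 + 4 * x_retain + 1).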
12953 success_vs_raises = [0, 0] 12954 12955 def train_fn_no_retain_graph(x): 12956 y = x + x**2 12957 try: 12958 y.sum().backward() 12959 success_vs_raises[0] += 1 12960 except RuntimeError as error: 12961 success_vs_raises[1] += 1 12962 self.assertRegex(str(error), "Specify retain_graph=True") 12963 12964 x_no_retain = torch.ones(5, 5, requires_grad=True) 12965 y_no_retain = x_no_retain + x_no_retain**2 12966 self._run_py_multithread_fn( 12967 train_fn_no_retain_graph, (y_no_retain,), num_threads=5 12968 ) 12969 # at least one thread will succeed in this case, all other threads should raise 12970 # with an error recommending that the user specify retain_graph=True 12971 self.assertTrue(success_vs_raises[0] >= 1) 12972 12973 # multiple backward with python threads, no error with retain_graph=True 12974 def train_fn_retain_graph(x): 12975 y = x + x**2 12976 y.sum().backward(retain_graph=True) 12977 12978 x_retain = torch.ones(5, 5, requires_grad=True) 12979 y_retain = x_retain + x_retain**2 12980 self._run_py_multithread_fn(train_fn_retain_graph, (y_retain,), num_threads=5) 12981 # result should equal num_threads * gradient 12982 self.assertEqual( 12983 x_retain.grad, 12984 5 * (4 * x_retain**3 + 6 * (x_retain**2) + 4 * x_retain + 1), 12985 ) 12986 12987 def test_fork_join_in_middle(self): 12988 # multiple backward with jit threads (fork/join primitive) 12989 # similar to test_python_thread_in_middle, we test with retain_graph=False/True 12990 12991 # Case 1: multiple grad() calls with jit threads, retain_graph=False 12992 # should throw error in some threads with no retain_graph. 12993 @torch.jit.script 12994 def train_fn_jit_no_retain(middle, orig_x): 12995 y = middle + middle**2 12996 return torch.autograd.grad([y.sum()], [orig_x]) 12997 12998 @torch.jit.script 12999 def train_fn_fork_join_calls_no_retain(x): 13000 y_no_retain = (x + 3) * (x + 4) * 0.5 13001 13002 fut = torch.jit._fork(train_fn_jit_no_retain, y_no_retain, x) 13003 grad_hat = train_fn_jit_no_retain(y_no_retain, x) 13004 grad = torch.jit._wait(fut) 13005 return grad, grad_hat 13006 13007 try: 13008 train_fn_fork_join_calls_no_retain(torch.randn(5, 5, requires_grad=True)) 13009 except RuntimeError as error: 13010 self.assertRegex(str(error), "Specify retain_graph=True") 13011 13012 # Case 2: no error with retain_graph=True 13013 @torch.jit.script 13014 def train_fn_jit_retain(middle, orig_x): 13015 y = middle + middle**2 13016 return torch.autograd.grad([y.sum()], [orig_x], retain_graph=True) 13017 13018 @torch.jit.script 13019 def train_fn_fork_join_calls_retain(x): 13020 y_retain = (x + 3) * (x + 4) * 0.5 13021 fut1 = torch.jit._fork(train_fn_jit_retain, y_retain, x) 13022 fut2 = torch.jit._fork(train_fn_jit_retain, y_retain, x) 13023 grad = train_fn_jit_retain(y_retain, x) 13024 grad1 = torch.jit._wait(fut1) 13025 grad2 = torch.jit._wait(fut2) 13026 return grad, grad1, grad2 13027 13028 grad, grad1, grad2 = train_fn_fork_join_calls_retain( 13029 torch.randn(5, 5, requires_grad=True) 13030 ) 13031 self.assertEqual(grad, grad1) 13032 self.assertEqual(grad, grad2) 13033 13034 def test_preserve_backtrace(self): 13035 class Foo(torch.autograd.Function): 13036 @staticmethod 13037 def forward(ctx, input): 13038 return input 13039 13040 @staticmethod 13041 def backward(ctx, *grad): 13042 raise ValueError("something") 13043 13044 t = torch.rand(10, requires_grad=True) 13045 try: 13046 Foo.apply(t).sum().backward() 13047 except Exception: 13048 import traceback 13049 13050 tb = sys.exc_info()[2] 13051 tb_str = 
"\n".join(traceback.format_tb(tb)) 13052 self.assertTrue('raise ValueError("something")' in tb_str) 13053 13054 # TODO(@anjali411): add an OpInfo based test for torch.cat 13055 # Issue: https://github.com/pytorch/pytorch/issues/51627 13056 # https://github.com/pytorch/pytorch/issues/75852 13057 def test_cat_stack_r_to_c(self): 13058 inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True) 13059 inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True) 13060 13061 def fn(x1, x2): 13062 return torch.cat((x1, x2), dim=-1) 13063 13064 def fn2(x1, x2): 13065 return torch.stack((x1, x2), dim=-1) 13066 13067 torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True) 13068 torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True) 13069 13070 torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True) 13071 torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True) 13072 13073 def test_set_multithreading_enabled_as_context_manager_and_function(self): 13074 # Test as a context manager 13075 with torch.autograd.set_multithreading_enabled(False): 13076 self.assertFalse(torch.autograd.is_multithreading_enabled()) 13077 self.assertTrue(torch.autograd.is_multithreading_enabled()) 13078 13079 with torch.autograd.set_multithreading_enabled(True): 13080 self.assertTrue(torch.autograd.is_multithreading_enabled()) 13081 self.assertTrue(torch.autograd.is_multithreading_enabled()) 13082 13083 with torch.autograd.set_multithreading_enabled(False): 13084 torch.autograd.set_multithreading_enabled(True) 13085 self.assertTrue(torch.autograd.is_multithreading_enabled()) 13086 self.assertTrue(torch.autograd.is_multithreading_enabled()) 13087 13088 torch.autograd.set_multithreading_enabled(False) 13089 self.assertFalse(torch.autograd.is_multithreading_enabled()) 13090 13091 torch.autograd.set_multithreading_enabled(True) 13092 self.assertTrue(torch.autograd.is_multithreading_enabled()) 13093 13094 @unittest.skipIf(not TEST_CUDA, "test requires CUDA") 13095 def test_custom_function_propagates_errors_from_device_thread(self): 13096 class MyFunc(Function): 13097 @staticmethod 13098 def forward(ctx, x): 13099 return x 13100 13101 @staticmethod 13102 def backward(ctx, gO): 13103 raise RuntimeError("blah") 13104 return gO 13105 13106 t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda")) 13107 out = MyFunc.apply(t).sum() 13108 13109 with self.assertRaisesRegex(RuntimeError, "blah"): 13110 out.backward() 13111 13112 13113class TestNestedCheckpoint(TestCase): 13114 @staticmethod 13115 def grad(fn): 13116 def wrapper(x): 13117 with torch.enable_grad(): 13118 out = fn(x) 13119 (grad_input,) = torch.autograd.grad(out, inputs=(x,), create_graph=True) 13120 return grad_input 13121 13122 return wrapper 13123 13124 @staticmethod 13125 def sum(fn): 13126 def wrapped(x): 13127 return fn(x).sum() 13128 13129 return wrapped 13130 13131 @staticmethod 13132 def checkpoint(fn): 13133 def wrapped(*args, **kwargs): 13134 return torch.utils.checkpoint.checkpoint( 13135 fn, *args, use_reentrant=False, **kwargs 13136 ) 13137 13138 return wrapped 13139 13140 def get_tests(self, fn): 13141 grad, c = self.grad, self.checkpoint 13142 13143 tests = ( 13144 # function <> tuple of function arbitrarily wrapped in checkpoint in various ways 13145 (fn, (c(fn), c(c(fn)))), 13146 (grad(fn), (grad(c(fn)), grad(c(c(fn))))), 13147 ( 13148 grad(grad(fn)), 13149 (grad(c(grad(fn))), c(grad(grad(c(fn)))), grad(c(grad(c(fn))))), 13150 ), 13151 ( 13152 grad(grad(grad(fn))), 13153 
(grad(c(grad(grad(c(fn))))), grad(c(grad(c(grad(c(fn))))))), 13154 ), 13155 ) 13156 return tests 13157 13158 def check_graph_dies(self, fn): 13159 def iter_graph(roots): 13160 if not roots: 13161 return 13162 seen = set() 13163 q = collections.deque() 13164 for node in roots: 13165 if node is not None: 13166 seen.add(node) 13167 q.append(node) 13168 13169 while q: 13170 node = q.popleft() 13171 for fn, _idx in node.next_functions: 13172 if fn in seen or fn is None: 13173 continue 13174 seen.add(fn) 13175 q.append(fn) 13176 13177 yield node 13178 13179 class Handle: 13180 __slot__ = ["node_name"] 13181 13182 def __init__(self, node_name): 13183 self.node_name = node_name 13184 13185 def scope(): 13186 a = torch.randn((), requires_grad=True) 13187 out = fn(a) 13188 refs = [] 13189 for node in iter_graph([out.grad_fn]): 13190 handle = Handle(node.name()) 13191 refs.append(weakref.ref(handle)) 13192 node.metadata["blah"] = handle 13193 return refs 13194 13195 refs = scope() 13196 node_names = [ref().node_name for ref in refs if ref() is not None] 13197 if len(node_names) > 0: 13198 print("Nodes still alive:", node_names) 13199 13200 self.assertEqual(len(node_names), 0) 13201 13202 @parametrize("early_stop", [True, False]) 13203 def test_nested_checkpoint(self, early_stop): 13204 with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop): 13205 x = torch.randn((), requires_grad=True) 13206 13207 def f(x): 13208 out = x.sin().exp().sin() 13209 return out 13210 13211 def g(x): 13212 a = x.sin().exp().sin() 13213 b = x.sin().exp().sin() 13214 (ga,) = torch.autograd.grad(a, x) 13215 (gb,) = torch.autograd.grad(b, x) 13216 return x.sin() 13217 13218 for fn in (f, g): 13219 for expected_fn, actual_fns in self.get_tests(fn): 13220 expected = expected_fn(x) 13221 13222 for actual_fn in actual_fns: 13223 actual = actual_fn(x) 13224 self.assertTrue(torch.allclose(expected, actual)) 13225 self.check_graph_dies(actual_fn) 13226 13227 @parametrize("early_stop", [True, False]) 13228 def test_nested_checkpoint_two_children(self, early_stop): 13229 with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop): 13230 grad, sum, c = self.grad, self.sum, self.checkpoint 13231 13232 def f(x): 13233 return x.sin().exp().sin() 13234 13235 def g(x): 13236 return x.cos().sin().exp() 13237 13238 def hc(x): 13239 return c(g)(c(f)(x)) 13240 13241 def h(x): 13242 return g(f(x)) 13243 13244 a = torch.randn(3, 3, requires_grad=True) 13245 expected = grad(sum(grad(sum(h))))(a) 13246 actual = grad(sum(grad(sum(c(hc)))))(a) 13247 self.assertTrue(torch.allclose(expected, actual)) 13248 13249 actual = grad(sum(c(grad(sum(c(hc))))))(a) 13250 self.assertTrue(torch.allclose(expected, actual)) 13251 13252 self.check_graph_dies(grad(c(hc))) 13253 self.check_graph_dies(grad(sum(grad(sum(c(hc)))))) 13254 self.check_graph_dies(grad(sum(c(grad(sum(c(hc))))))) 13255 13256 @parametrize("early_stop", [True, False]) 13257 def test_nested_checkpoint_non_tensor_inputs_and_outputs(self, early_stop): 13258 def fn(k, a, b, f): 13259 return f(k * a * b.exp()), 1, "abcd" 13260 13261 k = 3 13262 a = torch.tensor(2.0, requires_grad=True) 13263 b = torch.tensor(3.0, requires_grad=True) 13264 13265 def f(x): 13266 return x.sin() 13267 13268 with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop): 13269 out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False) 13270 actual_grads = torch.autograd.grad(out, (a, b)) 13271 13272 out, _unused1, _unused2 = fn(k, a, b, f) 13273 expected_grads = torch.autograd.grad(out, (a, 
b)) 13274 for actual, expected in zip(actual_grads, expected_grads): 13275 self.assertTrue(torch.allclose(actual, expected)) 13276 13277 @parametrize("early_stop", [True, False]) 13278 def test_nested_checkpoint_kwargs(self, early_stop): 13279 def fn(a, blah=None): 13280 out = a.sin().exp() 13281 if blah is not None: 13282 out = out * blah 13283 return out.sin().exp() 13284 13285 a = torch.tensor(2.0, requires_grad=True) 13286 b = torch.tensor(3.0, requires_grad=True) 13287 13288 with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop): 13289 out = checkpoint(fn, a, blah=b, use_reentrant=False) 13290 actual_grads = torch.autograd.grad(out, (a, b)) 13291 13292 out = fn(a, blah=b) 13293 expected_grads = torch.autograd.grad(out, (a, b)) 13294 for actual, expected in zip(actual_grads, expected_grads): 13295 self.assertTrue(torch.allclose(actual, expected)) 13296 13297 @parametrize("early_stop", [True, False]) 13298 def test_nested_checkpoint_same_graph(self, early_stop): 13299 counter = [0] 13300 13301 def hook(*_unused_args): 13302 counter[0] += 1 13303 13304 def fn(a): 13305 return a.sin().cos().sin() 13306 13307 a = torch.tensor(1.0, requires_grad=True) 13308 13309 with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop): 13310 out = checkpoint(fn, a, use_reentrant=False) 13311 # The hook is registered on the original graph 13312 out.grad_fn.next_functions[0][0].register_hook(hook) 13313 # And backward is performed on the original graph 13314 out.backward() 13315 13316 self.assertEqual(counter[0], 1) 13317 13318 @parametrize("early_stop", [True, False]) 13319 def test_nested_checkpoint_reentrant_backwards(self, early_stop): 13320 def fn(a): 13321 x = a.sin().cos() 13322 out = x.sin() 13323 return x, out 13324 13325 def hook(*_unused_args): 13326 # do backward again, but skip over the part of the graph where 13327 # the hook was registered 13328 x.backward(retain_graph=True) 13329 13330 a = torch.tensor(1.0, requires_grad=True) 13331 with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop): 13332 x, out = checkpoint(fn, a, use_reentrant=False) 13333 out.grad_fn.register_hook(hook) 13334 out.backward(retain_graph=True) 13335 13336 def test_nested_checkpoint_set_early_stop(self): 13337 counter = [0] 13338 13339 def clone(x): 13340 counter[0] += 1 13341 return x.clone() 13342 13343 def fn(x): 13344 # Since clone does not save anything, it is not recomputed iff 13345 # early stop is enabled. 13346 return clone(x.sin().cos()) 13347 13348 # Early stopping is enabled by default 13349 a = torch.tensor(1.0, requires_grad=True) 13350 out = checkpoint(fn, a, use_reentrant=False) 13351 out.backward() 13352 self.assertEqual(counter[0], 1) 13353 13354 # Try using the context manager to set early stopping to False. 13355 # Expect early stopping to be disabled for all checkpoints ran under 13356 # the context manager, even though context manager is no longer active 13357 # when backward/recomputation is performed. 
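        # With early stopping disabled, unpacking during backward replays fn from the start,
        # so clone() runs once in the original forward and once more in the recompute, which
        # is why the counter below is expected to reach 2.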
13358 counter = [0] 13359 a = torch.tensor(1.0, requires_grad=True) 13360 with torch.utils.checkpoint.set_checkpoint_early_stop(False): 13361 out = checkpoint(fn, a, use_reentrant=False) 13362 13363 out.backward() 13364 self.assertEqual(counter[0], 2) 13365 13366 def test_nested_checkpoint_set_early_stop_no_recomputation_needed(self): 13367 # Case 1: We have one tensor saved and it's the input 13368 13369 # We have two different counters here because in this case we actually 13370 # do call into x.sin() at the python level during recomputation whether 13371 # or not early stop is enabled. This is because the early stopping 13372 # only happens at the autograd level (preventing us from reaching the 13373 # backend). 13374 python_dispatch_counter = [0] 13375 counter = [0] 13376 13377 class SinCounterMode(TorchDispatchMode): 13378 def __init__(self) -> None: 13379 self.count = 0 13380 13381 def __torch_dispatch__(self, func, types, args=(), kwargs=None): 13382 kwargs = {} if kwargs is None else kwargs 13383 if func is torch.ops.aten.sin.default: 13384 self.count += 1 13385 return func(*args, **kwargs) 13386 13387 def fn(x): 13388 counter[0] += 1 13389 return x.sin() 13390 13391 # With early stopping (enabled by default) 13392 a = torch.tensor(1.0, requires_grad=True) 13393 with SinCounterMode() as python_dispatch_counter: # noqa: F811 13394 out = checkpoint(fn, a, use_reentrant=False) 13395 out.backward() 13396 self.assertEqual(counter[0], 2) 13397 self.assertEqual(python_dispatch_counter.count, 1) 13398 13399 # Without early stopping 13400 counter = [0] 13401 a = torch.tensor(1.0, requires_grad=True) 13402 with SinCounterMode() as python_dispatch_counter: 13403 with torch.utils.checkpoint.set_checkpoint_early_stop(False): 13404 out = checkpoint(fn, a, use_reentrant=False) 13405 out.backward() 13406 self.assertEqual(counter[0], 2) 13407 self.assertEqual(python_dispatch_counter.count, 2) 13408 13409 # Case 2: Forward saves no tensors 13410 13411 # Since unpack isn't even called, counter is 1 whether or not early stop 13412 # is enabled! 13413 counter = [0] 13414 13415 def fn2(x): 13416 counter[0] += 1 13417 return x.clone() 13418 13419 # With early stopping (enabled by default) 13420 a = torch.tensor(1.0, requires_grad=True) 13421 out = checkpoint(fn2, a, use_reentrant=False) 13422 out.backward() 13423 self.assertEqual(counter[0], 1) 13424 13425 # Without early stopping 13426 counter = [0] 13427 a = torch.tensor(1.0, requires_grad=True) 13428 with torch.utils.checkpoint.set_checkpoint_early_stop(False): 13429 out = checkpoint(fn2, a, use_reentrant=False) 13430 out.backward() 13431 self.assertEqual(counter[0], 1) 13432 13433 13434class TestSelectiveActivationCheckpoint(TestCase): 13435 @unittest.skipIf(not TEST_CUDA, "requires CUDA") 13436 def test_flops_and_mem(self): 13437 # From https://github.com/pytorch/pytorch/pull/126320 13438 def get_act_mem(f): 13439 out = f() 13440 out.backward() 13441 # Why do one forward and backward? 13442 start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] 13443 out = f() 13444 cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"] 13445 act_mem = (cur_mem - start_mem) / (1024 * 1024) 13446 out.backward() 13447 return act_mem 13448 13449 def get_bw_flops(f): 13450 # Normalized so that a 512 square matmul returns 1 13451 f().backward() 13452 out = f() 13453 # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so 13454 # it will be able to observe whether an op is cached or not. 
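            # A 512 x 512 x 512 matmul does 2 * 512**3 FLOPs (one multiply and one add per
            # output element per reduction step), hence the 512**3 * 2 divisor below.
            # Rough accounting for the checks further down (pointwise ops like sin/cos do not
            # register FLOPs here): the backward of a single mm needs two matmuls, one per
            # input gradient -> 2.0 without checkpointing; full activation checkpointing also
            # recomputes the forward mm -> 3.0; the SAC variants save the mm output, so
            # nothing extra is recomputed -> back to 2.0.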
13455 with FlopCounterMode(display=False) as mode: 13456 out.backward() 13457 return mode.get_total_flops() / (512**3 * 2) 13458 13459 x = torch.randn(512, 512, requires_grad=True, device="cuda") 13460 y = torch.randn(512, 512, requires_grad=True, device="cuda") 13461 13462 def fn(x, y): 13463 return torch.mm(x.cos(), y).sin().sum() 13464 13465 def fn_ac(x, y): 13466 return checkpoint(fn, x, y, use_reentrant=False) 13467 13468 def fn_sac(x, y): 13469 context_fn = functools.partial( 13470 create_selective_checkpoint_contexts, 13471 [torch.ops.aten.mm.default], 13472 ) 13473 out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) 13474 return out 13475 13476 def policy_fn(ctx, op, *args, **kwargs): 13477 if op == torch.ops.aten.mm.default: 13478 return CheckpointPolicy.MUST_SAVE 13479 else: 13480 return CheckpointPolicy.PREFER_RECOMPUTE 13481 13482 def fn_sac2(x, y): 13483 context_fn = functools.partial( 13484 create_selective_checkpoint_contexts, 13485 policy_fn, 13486 ) 13487 out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) 13488 return out 13489 13490 def policy_fn_bool(ctx, op, *args, **kwargs): 13491 return op == torch.ops.aten.mm.default 13492 13493 def fn_sac3(x, y): 13494 context_fn = functools.partial( 13495 create_selective_checkpoint_contexts, 13496 policy_fn_bool, 13497 ) 13498 out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) 13499 return out 13500 13501 act_mem_noac = get_act_mem(lambda: fn(x, y)) 13502 bw_flops_noac = get_bw_flops(lambda: fn(x, y)) 13503 13504 self.assertEqual(act_mem_noac, 2.0) 13505 self.assertEqual(bw_flops_noac, 2.0) 13506 13507 act_mem_ac = get_act_mem(lambda: fn_ac(x, y)) 13508 bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y)) 13509 13510 self.assertEqual(act_mem_ac, 0.0) 13511 self.assertEqual(bw_flops_ac, 3.0) 13512 13513 act_mem_sac = get_act_mem(lambda: fn_sac(x, y)) 13514 bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y)) 13515 13516 self.assertEqual(act_mem_sac, 1.0) 13517 self.assertEqual(bw_flops_sac, 2.0) 13518 13519 act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y)) 13520 bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y)) 13521 13522 self.assertEqual(act_mem_sac2, 1.0) 13523 self.assertEqual(bw_flops_sac2, 2.0) 13524 13525 act_mem_sac3 = get_act_mem(lambda: fn_sac3(x, y)) 13526 bw_flops_sac3 = get_bw_flops(lambda: fn_sac3(x, y)) 13527 13528 self.assertEqual(act_mem_sac3, 1.0) 13529 self.assertEqual(bw_flops_sac3, 2.0) 13530 13531 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13532 def test_output_already_has_autograd_meta(self): 13533 # View of tensor of non-differentiable dtype still has AutogradMeta 13534 def fn(x, y): 13535 return x.view(-1), y.sin().cos() 13536 13537 x = torch.tensor([1, 2, 3], dtype=torch.int64) 13538 y = torch.randn(3, requires_grad=True) 13539 13540 context_fn = functools.partial( 13541 create_selective_checkpoint_contexts, 13542 [torch.ops.aten.view.default], 13543 ) 13544 out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn) 13545 out[1].sum().backward() 13546 13547 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13548 def test_subclass_dispatching_sizes(self): 13549 # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default 13550 # Caching such metadata ops can be problematic when the following are satisfied: 13551 # 13552 # 1. size/strides are dispatched upon 13553 # 2. 
our policy saves sizes 13554 ta = torch.randn(6, 2) 13555 13556 class CustomSizeDynamicShapesTensor(torch.Tensor): 13557 @staticmethod 13558 def __new__(cls, inner): 13559 return torch.Tensor._make_wrapper_subclass( 13560 # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. 13561 # Calling the overload that has kwargs causes us to go down the first overload path, 13562 # which will **always** specialize sizes. 13563 # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 13564 cls, 13565 inner.size(), 13566 inner.stride(), 13567 None, 13568 None, 13569 inner.dtype, 13570 inner.layout, 13571 inner.device, 13572 False, 13573 inner.requires_grad, 13574 "sizes", 13575 ) 13576 13577 def __init__(self, inner): 13578 self.inner = inner 13579 13580 @classmethod 13581 def __torch_dispatch__(cls, func, types, args, kwargs): 13582 if kwargs is None: 13583 kwargs = {} 13584 args_inner = torch.utils._pytree.tree_map_only( 13585 cls, lambda x: x.inner, args 13586 ) 13587 out_inner = func(*args_inner, **kwargs) 13588 return torch.utils._pytree.tree_map_only( 13589 torch.Tensor, lambda x: cls(x), out_inner 13590 ) 13591 13592 def policy_fn(ctx, op, *args, **kwargs): 13593 if op is torch.ops.aten.sym_size.default: 13594 # Silently ignored! 13595 return CheckpointPolicy.MUST_SAVE 13596 else: 13597 return CheckpointPolicy.PREFER_RECOMPUTE 13598 13599 def fn(x): 13600 # We avoid the following case 13601 # 13602 # saved :[4, 3], [], [], [4, 3], [4, 3], [4, 3], [12] 13603 # forward :sum ,sum,mul, mul , mul ,view , view 13604 # recompute :sum ,sum,mul, view , view 13605 # 13606 # Views save the shape of their input, so we expect the second 13607 # view to save 12, but because during AC packing during forward 13608 # saves the shapes of the input for metadata checks later, 13609 # we would save the wrong shape during the recompute. 13610 view_out = (x * x.sum()).view(-1).view(4, 3) 13611 self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12]) 13612 return view_out.exp() 13613 13614 x = torch.randn(4, 3, requires_grad=True) 13615 x_wrapper = CustomSizeDynamicShapesTensor(x) 13616 context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 13617 out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn) 13618 out.sum().backward() 13619 13620 def test_bad_inputs(self): 13621 bad_op_list1 = [2] 13622 13623 with self.assertRaisesRegex( 13624 ValueError, "Expected op in `op_list` to be an OpOverload" 13625 ): 13626 create_selective_checkpoint_contexts(bad_op_list1) 13627 13628 bad_op_list2 = [torch.ops.aten.sin] 13629 13630 with self.assertRaisesRegex( 13631 ValueError, "update the OpOverloadPacket to a specific OpOverload" 13632 ): 13633 create_selective_checkpoint_contexts(bad_op_list2) 13634 13635 with self.assertRaisesRegex(TypeError, "either a function or a list of ops."): 13636 create_selective_checkpoint_contexts(2) 13637 13638 # Dynamo fails for various reasons: 13639 # - some tests using custom op that does not implement Fake 13640 # - dynamo is trying to trace into saved variable hooks unpack hook for some reason 13641 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13642 def test_policy_with_state(self): 13643 # If I have a stateful callable, state is shared between the original 13644 # forward and the recompute. 
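        # The single Policy instance constructed below is handed to
        # create_selective_checkpoint_contexts, so its attributes persist across both passes;
        # the ctx passed to the policy exposes ctx.is_recompute, which lets the callable tell
        # whether it is being invoked from the original forward or from the recompute.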
13645 counters = [] 13646 13647 class Policy: 13648 def __init__(self) -> None: 13649 self.counter = [0] 13650 self.recompute_counter = [0] 13651 13652 def __call__(self, ctx, func, *args, **kwargs): 13653 counter = self.recompute_counter if ctx.is_recompute else self.counter 13654 counter[0] += 1 13655 counters.append(counter[0]) 13656 if counter[0] == 1 and func is torch.ops.aten.mm.default: 13657 return CheckpointPolicy.MUST_SAVE 13658 return CheckpointPolicy.PREFER_RECOMPUTE 13659 13660 def fn(x): 13661 return x.sin().sin().sin() 13662 13663 x = torch.randn(3, requires_grad=True) 13664 context_fn = functools.partial( 13665 create_selective_checkpoint_contexts, 13666 Policy(), 13667 allow_cache_entry_mutation=True, 13668 ) 13669 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13670 out.sum().backward() 13671 # 1. counter properly reset to 0 for the recompute 13672 # 2. due to early-stop we do not recompute the final op 13673 self.assertEqual(counters, [1, 2, 3, 1, 2]) 13674 13675 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13676 def test_storage_lifetime(self): 13677 from torch.utils._python_dispatch import _get_current_dispatch_mode 13678 from torch.utils.checkpoint import ( 13679 _CachedTorchDispatchMode, 13680 _CachingTorchDispatchMode, 13681 ) 13682 13683 def policy_fn(ctx, op, *args, **kwargs): 13684 return CheckpointPolicy.MUST_SAVE 13685 13686 ref = None 13687 13688 def fn(x): 13689 nonlocal ref 13690 13691 self.assertIsInstance( 13692 _get_current_dispatch_mode(), 13693 (_CachingTorchDispatchMode, _CachedTorchDispatchMode), 13694 ) 13695 13696 out = x.cos().exp() 13697 13698 if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode): 13699 raw_val = ( 13700 _get_current_dispatch_mode() 13701 .storage[torch.ops.aten.exp.default][0] 13702 .val 13703 ) 13704 # ref should've been detached 13705 # to avoid graph -> the saved variable hooks -> recompute_context -> storage -> graph 13706 self.assertFalse(raw_val.requires_grad) 13707 ref = weakref.ref(raw_val) 13708 13709 # Careful for early-stop 13710 return out.sin() 13711 13712 with disable_gc(): 13713 # Case 1: If graph goes away without backward, make sure there's no reference cycle 13714 # keeping storage alive. 13715 x = torch.randn(3, requires_grad=True) 13716 context_fn = functools.partial( 13717 create_selective_checkpoint_contexts, policy_fn 13718 ) 13719 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13720 self.assertIsNotNone(ref()) 13721 del out 13722 self.assertIsNone(ref()) 13723 13724 # Case 2: After backward, even if retain_graph=True, the storage should go away 13725 x = torch.randn(3, requires_grad=True) 13726 context_fn = functools.partial( 13727 create_selective_checkpoint_contexts, policy_fn 13728 ) 13729 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13730 self.assertIsNotNone(ref()) 13731 out.sum().backward(retain_graph=True) 13732 # The dispatch mode's storage should still be alive, but the entries should've 13733 # been cleared. 
13734 self.assertIsNone(ref()) 13735 13736 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13737 def test_version_counter(self): 13738 def policy_fn(ctx, op, *args, **kwargs): 13739 if op == torch.ops.aten.sin.default: 13740 return CheckpointPolicy.MUST_SAVE 13741 else: 13742 return CheckpointPolicy.PREFER_RECOMPUTE 13743 13744 def fn(x): 13745 return x.sin().mul_(2).cos().exp() 13746 13747 x = torch.randn(3, requires_grad=True) 13748 context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 13749 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13750 13751 # 1) Error because the output of sin is saved and mutated by mul_ 13752 with self.assertRaisesRegex(RuntimeError, "has been mutated"): 13753 out.sum().backward() 13754 13755 x = torch.randn(3, requires_grad=True) 13756 context_fn = functools.partial( 13757 create_selective_checkpoint_contexts, 13758 policy_fn, 13759 allow_cache_entry_mutation=True, 13760 ) 13761 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13762 13763 # 2) No longer should be an error because of allow_cache_entry_mutation 13764 out.sum().backward() 13765 13766 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13767 def test_function_with_more_than_one_output(self): 13768 # maybe there is a more systematic way: 13769 counter = [0] 13770 13771 def policy_fn(ctx, op, *args, **kwargs): 13772 if op == torch.ops.aten.var_mean.correction: 13773 counter[0] += 1 13774 return CheckpointPolicy.MUST_SAVE 13775 else: 13776 return CheckpointPolicy.PREFER_RECOMPUTE 13777 13778 # var_mean has two outputs 13779 def fn(x): 13780 a, b = torch.var_mean(x) 13781 return a * b 13782 13783 x = torch.randn(3, requires_grad=True) 13784 context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 13785 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13786 x_grad = torch.autograd.grad(out.sum(), (x,)) 13787 x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,)) 13788 self.assertEqual(x_grad, x_grad_ref) 13789 self.assertEqual(counter[0], 2) 13790 13791 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13792 def test_function_with_non_tensor_output(self): 13793 # When SAC is enabled, the op is not computed a second time 13794 with torch.library._scoped_library("mylib", "FRAGMENT") as lib: 13795 counter = [0] 13796 13797 @torch.library.custom_op("mylib::sin_with_extra", mutates_args=()) 13798 def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]: 13799 counter[0] += 1 13800 return x.sin(), 2 13801 13802 def setup_context(ctx, inputs, output) -> torch.Tensor: 13803 (x,) = inputs 13804 ctx.save_for_backward(x) 13805 13806 def backward(ctx, grad, _unused): 13807 (x,) = ctx.saved_tensors 13808 return grad * x.cos() 13809 13810 torch.library.register_autograd( 13811 "mylib::sin_with_extra", backward, setup_context=setup_context 13812 ) 13813 13814 x = torch.randn(3, requires_grad=True) 13815 13816 def fn(x): 13817 return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin() 13818 13819 ops_list = [torch.ops.mylib.sin_with_extra.default] 13820 13821 x = torch.randn(3, requires_grad=True) 13822 context_fn = functools.partial( 13823 create_selective_checkpoint_contexts, ops_list 13824 ) 13825 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13826 x_grad = torch.autograd.grad(out.sum(), (x,)) 13827 self.assertEqual(counter[0], 1) 13828 x_grad_ref = 
torch.autograd.grad(fn(x).sum(), (x,)) 13829 self.assertEqual(x_grad, x_grad_ref) 13830 13831 @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py") 13832 def test_can_only_trigger_recompute_once(self): 13833 # We don't support this to avoid adding extra complexity for now. 13834 # If there's a need, we could probably do some kind of use_count tracking. 13835 # TODO: have a nice error message here. 13836 def policy_fn(ctx, op, *args, **kwargs): 13837 if op == torch.ops.aten.sin.default: 13838 return CheckpointPolicy.MUST_SAVE 13839 else: 13840 return CheckpointPolicy.PREFER_RECOMPUTE 13841 13842 def fn(x): 13843 return x.sin().cos().exp() 13844 13845 x = torch.randn(3, requires_grad=True) 13846 context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) 13847 out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn) 13848 out.sum().backward(retain_graph=True) 13849 13850 with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"): 13851 out.sum().backward(retain_graph=True) 13852 13853 13854class TestAutogradMultipleDispatch(TestCase): 13855 def test_autograd_multiple_dispatch_registrations(self, device): 13856 t = torch.randn(3, 3, device=device, requires_grad=True) 13857 # using _test_autograd_multiple_dispatch.fullcoverage which has 13858 # registrations in derivatives.yaml for Default, AutogradCUDA and NestedTensorAutograd 13859 out = torch._test_autograd_multiple_dispatch(t) 13860 grad = torch.randn(3, 3, device=device) 13861 out.backward(grad) 13862 13863 if "cuda" not in device: 13864 # bogus default gradient registered for Autograd is grad + 1 13865 self.assertEqual(t.grad, grad + 1) 13866 else: 13867 # bogus gradient registered for AutogradCUDA is grad * 2 13868 self.assertEqual(t.grad, grad * 2) 13869 13870 # test registered AutogradNestedTensor formula 13871 a = ( 13872 torch.arange(6, dtype=torch.float, device=device) 13873 .reshape(2, 3) 13874 .requires_grad_(True) 13875 ) 13876 b = ( 13877 torch.arange(8, dtype=torch.float, device=device) 13878 .reshape(2, 4) 13879 .requires_grad_(True) 13880 ) 13881 nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device) 13882 13883 nt_out = torch._test_autograd_multiple_dispatch(nt) 13884 c = torch.randn(2, 3, device=device) 13885 d = torch.randn(2, 4, device=device) 13886 nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device) 13887 nt_out.backward(nt_grad) 13888 13889 # bogus gradient for AutogradNestedTensor is grad * grad 13890 self.assertEqual(a.grad, c * c) 13891 self.assertEqual(b.grad, d * d) 13892 13893 def test_autograd_composite_implicit_and_dispatch_registration(self, device): 13894 t = torch.randn(3, 3, device=device, requires_grad=True) 13895 # using _test_autograd_multiple_dispatch.ntonly 13896 # which has registrations in derivatives.yaml for NestedTensorAutograd and otherwise is CompositeImplicit 13897 out = torch._test_autograd_multiple_dispatch(t, True) 13898 grad = torch.randn(3, 3, device=device) 13899 out.backward(grad) 13900 13901 # t.grad is just out.grad by composite op since _test_autograd_multiple_dispatch is just a clone 13902 self.assertEqual(t.grad, grad) 13903 13904 # test registered AutogradNestedTensor formula 13905 a = ( 13906 torch.arange(6, dtype=torch.float, device=device) 13907 .reshape(2, 3) 13908 .requires_grad_(True) 13909 ) 13910 b = ( 13911 torch.arange(8, dtype=torch.float, device=device) 13912 .reshape(2, 4) 13913 .requires_grad_(True) 13914 ) 13915 nt = 
torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device) 13916 13917 nt_out = torch._test_autograd_multiple_dispatch(nt, True) 13918 c = torch.randn(2, 3, device=device) 13919 d = torch.randn(2, 4, device=device) 13920 nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device) 13921 nt_out.backward(nt_grad) 13922 13923 # bogus gradient for AutogradNestedTensor is grad * grad + grad 13924 self.assertEqual(a.grad, c * c + c) 13925 self.assertEqual(b.grad, d * d + d) 13926 13927 def test_foward_mode_AD(self, device): 13928 # check that forward mode AD is only registered for the Default 13929 # dispatch for _test_autograd_multiple_dispatch.fullcoverage and not AutogradCUDA 13930 13931 primal = torch.randn(3, device=device) 13932 tangent = torch.randn(3, device=device) 13933 13934 with fwAD.dual_level(): 13935 dual_input = fwAD.make_dual(primal, tangent) 13936 13937 err_msg = r"Trying to use forward AD with .* that does not support it" 13938 hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError" 13939 13940 if "cuda" in device: 13941 with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg): 13942 torch._test_autograd_multiple_dispatch(dual_input) 13943 else: 13944 torch._test_autograd_multiple_dispatch(dual_input) 13945 13946 def test_view_copy(self, device): 13947 # tests that view_copy derivative formulas are also generated per dispatch key 13948 # from their respective view ops in derivatives.yaml 13949 t = torch.randn(2, 2, device=device, requires_grad=True) 13950 t_ref = t.clone().detach().requires_grad_() 13951 # _test_autograd_multiple_dispatch_view does a .view(-1) on the input 13952 t_view = torch._test_autograd_multiple_dispatch_view(t_ref) 13953 t_view_copy = torch._test_autograd_multiple_dispatch_view_copy(t) 13954 13955 grad = torch.randn(4, device=device) 13956 t_view_copy.backward(grad) 13957 t_view.backward(grad.clone()) 13958 13959 # forward and backward give the same shape + result 13960 self.assertEqual(t_view_copy, t_view) 13961 self.assertEqual(t.grad, t_ref.grad) 13962 # backward results are per-dispatch-key in derivatives.yaml 13963 if "cuda" in device: 13964 # gradient registered to AutogradCUDA is grad.reshape_as(self) + 1 13965 self.assertEqual(t.grad, grad.reshape_as(t) + 1) 13966 else: 13967 # Default gradient registered is grad.reshape_as(self) 13968 self.assertEqual(t.grad, grad.reshape_as(t)) 13969 13970 @onlyCPU 13971 def test_per_dispatch_key_input_saving(self, device): 13972 # Tests that sum.dim_IntList's input is not saved for regular tensors but is saved for nested tensors 13973 def foo(x): 13974 # Don't modify the input inplace 13975 x = x.clone() 13976 res = x.sum(-1, keepdim=True) 13977 x.add_(x) 13978 return res 13979 13980 inp = torch.rand(2, device=device, requires_grad=True) 13981 # sum's input is not saved for regular Tensors 13982 foo(inp).backward() 13983 13984 # sum's input is saved for Nested Tensors 13985 nt = torch.nested.nested_tensor( 13986 [torch.rand(2), torch.rand(2)], device=device, requires_grad=True 13987 ) 13988 with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"): 13989 foo(nt).backward( 13990 torch.nested.nested_tensor( 13991 [torch.rand(1), torch.rand(1)], device=device 13992 ) 13993 ) 13994 13995 @onlyCUDA 13996 def test_backward_single_threaded(self): 13997 threads_eq = None 13998 13999 class TestFn(Function): 14000 @staticmethod 14001 def forward(ctx, x, self): 14002 ctx.self = self 14003 ctx.tid = 
threading.get_ident() 14004 return x.clone() 14005 14006 @staticmethod 14007 def backward(ctx, gO): 14008 nonlocal threads_eq 14009 threads_eq = ctx.tid == threading.get_ident() 14010 return gO, None 14011 14012 inp = torch.rand(10, device="cuda", requires_grad=True) 14013 14014 with torch.autograd.set_multithreading_enabled(False): 14015 TestFn.apply(inp, None).sum().backward() 14016 self.assertTrue(threads_eq) 14017 14018 TestFn.apply(inp, None).sum().backward() 14019 self.assertFalse(threads_eq) 14020 14021 @onlyCUDA 14022 def test_backward_tls_stash(self): 14023 local = threading.local() 14024 local.my_obj = {} 14025 local.my_obj[10] = 10 14026 test_self = self 14027 torch._C._stash_obj_in_tls("my_obj", local.my_obj) 14028 14029 class TestFn(Function): 14030 @staticmethod 14031 def forward(ctx, x, self): 14032 return x.clone() 14033 14034 @staticmethod 14035 def backward(ctx, gO): 14036 test_self.assertTrue(torch._C._is_key_in_tls("my_obj")) 14037 test_self.assertTrue(torch._C._get_obj_in_tls("my_obj")[10] == 10) 14038 torch._C._get_obj_in_tls("my_obj")[10] = 5 14039 return gO, None 14040 14041 inp = torch.rand(10, device="cuda", requires_grad=True) 14042 14043 TestFn.apply(inp, None).sum().backward() 14044 self.assertEqual(local.my_obj[10], 5) 14045 14046 def test_is_retain_graph(self): 14047 retain_graph_set = False 14048 14049 class TestFn(Function): 14050 @staticmethod 14051 def forward(ctx, x): 14052 return x.clone() 14053 14054 @staticmethod 14055 def backward(ctx, gO): 14056 nonlocal retain_graph_set 14057 retain_graph_set = ( 14058 torch._C._autograd._get_current_graph_task_keep_graph() 14059 ) 14060 return gO, None 14061 14062 inp = torch.rand(10, requires_grad=True) 14063 14064 out = TestFn.apply(inp) 14065 self.assertFalse(retain_graph_set) 14066 out.sum().backward(retain_graph=True) 14067 self.assertTrue(retain_graph_set) 14068 out.sum().backward(retain_graph=False) 14069 self.assertFalse(retain_graph_set) 14070 14071 def test_set_sequence_nr(self): 14072 x = torch.randn((10,), dtype=torch.float32, requires_grad=True) 14073 y = torch.randn((10,), dtype=torch.float32, requires_grad=True) 14074 z = torch.randn((10,), dtype=torch.float32, requires_grad=True) 14075 14076 a = x + y 14077 b = y + z 14078 c = a + b 14079 14080 self.assertIsNotNone(a.grad_fn) 14081 self.assertIsNotNone(b.grad_fn) 14082 self.assertIsNotNone(c.grad_fn) 14083 14084 a.grad_fn._set_sequence_nr(100) 14085 b.grad_fn._set_sequence_nr(99) 14086 c.grad_fn._set_sequence_nr(98) 14087 14088 self.assertEqual(a.grad_fn._sequence_nr(), 100) 14089 self.assertEqual(b.grad_fn._sequence_nr(), 99) 14090 self.assertEqual(c.grad_fn._sequence_nr(), 98) 14091 14092 def log_grad_order(grad: torch.Tensor, name: str, order): 14093 order.append(name) 14094 return grad 14095 14096 order = [] 14097 a.register_hook(partial(log_grad_order, name="a", order=order)) 14098 b.register_hook(partial(log_grad_order, name="b", order=order)) 14099 c.register_hook(partial(log_grad_order, name="c", order=order)) 14100 14101 c.sum().backward() 14102 14103 # Expect to see that even though c has the smallest sequence number, it is still the first node to get run in autograd. 14104 # Also check that although a comes first during the forward, after giving it priority with sequence_nr, 14105 # its autograd node is run before that of b. 
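        # One way to read this: sequence_nr priority only matters among nodes that are ready
        # at the same time. c's node is the only root, so it runs first despite its low
        # sequence_nr; once it finishes, a's and b's nodes are both ready and the engine
        # appears to prefer the larger sequence_nr, so a (100) runs before b (99).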
14106 self.assertEqual(order, ["c", "a", "b"]) 14107 14108 self.assertEqual(x.grad, torch.ones_like(x)) 14109 self.assertEqual(y.grad, 2 * torch.ones_like(x)) 14110 self.assertEqual(z.grad, torch.ones_like(x)) 14111 14112 14113# Import test cases from below autograd/ here. These are found 14114# implicitly by the loader, so Flake8 thinks they are unused, hence 14115# the suppressions. 14116 14117from autograd.test_complex import TestAutogradComplex # noqa: F401 14118from autograd.test_functional import TestAutogradFunctional # noqa: F401 14119from autograd.test_logging import TestAutogradLogging # noqa: F401 14120 14121 14122# e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA 14123instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None) 14124 14125instantiate_device_type_tests( 14126 TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda") 14127) 14128 14129instantiate_parametrized_tests(TestAutograd) 14130instantiate_parametrized_tests(TestNestedCheckpoint) 14131 14132if __name__ == "__main__": 14133 run_tests() 14134