# Owner(s): ["module: autograd"]

import collections
import contextlib
import functools
import gc
import io
import math
import operator
import os
import pickle
import random
import subprocess
import sys
import tempfile
import threading
import time
import unittest
import uuid
import warnings
import weakref
from collections import OrderedDict
from copy import deepcopy
from functools import partial, reduce
from itertools import product
from operator import mul
from typing import List, Tuple, TYPE_CHECKING

import torch
import torch.autograd._functions
import torch.autograd.forward_ad as fwAD
from torch import inf, nan, nn
from torch.autograd import (
    _calculate_shape,
    detect_anomaly,
    Function,
    kineto_available,
    Variable,
)
from torch.autograd.function import InplaceFunction, once_differentiable
from torch.autograd.graph import GradientEdge
from torch.autograd.profiler import emit_itt, emit_nvtx, profile, record_function
from torch.autograd.profiler_util import (
    _format_time,
    EventList,
    FunctionEvent,
    FunctionEventAvg,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    deviceCountAtLeast,
    dtypes,
    dtypesIfCUDA,
    dtypesIfMPS,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and
from torch.testing._internal.common_methods_invocations import mask_not_all_zeros
from torch.testing._internal.common_utils import (
    disable_gc,
    gradcheck,
    gradgradcheck,
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    run_tests,
    set_warn_always_context,
    skipIfMps,
    skipIfNoLapack,
    skipIfTorchDynamo,
    slowTest,
    TestCase,
    xfailIfTorchDynamo,
)
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils.checkpoint import (
    checkpoint,
    checkpoint_sequential,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)
from torch.utils.cpp_extension import load_inline
from torch.utils.flop_counter import FlopCounterMode


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle


def graph_desc(fn):
    if fn is None:
        return "None"
    result = type(fn).__name__ + "("
    next_functions = fn.next_functions
    for next_fn, _ in next_functions:
        result += graph_desc(next_fn)
        result += ", "
    if next_functions:
        result = result[:-2]
    return result + ")"

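# NOTE: graph_desc() above renders the autograd graph rooted at ``fn`` as a nested
# string of node type names; for example, adding two leaf tensors that require grad
# typically prints as "AddBackward0(AccumulateGrad(), AccumulateGrad())" (the exact
# node names depend on the ops used). Several tests below compare these strings
# against expected values.
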
class TestAutograd(TestCase):
    def test_copy_slices_graph_task_updates(self):
        def f1(x, y):
            out = x.clone().view(-1)
            out += y
            return out

        def f2(x, y):
            out = x.clone().view(-1)
            b = out * 2
            out += y
            return out + b

        x = torch.rand(2, requires_grad=True)
        y = torch.rand(2, requires_grad=True)

        y_safe = torch._C._functions.DelayedError("Boom!", 1)(y)

        for f in [f1, f2]:
            # Ensure that the error Node works
            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                out.sum().backward()

            out = f(x, y_safe)
            with self.assertRaisesRegex(RuntimeError, "Boom!"):
                torch.autograd.grad(out.sum(), y)

            # Ensure that if we don't ask for y, it doesn't crash
            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), x)

            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), y_safe)

            out = f(x, y_safe)
            torch.autograd.grad(out.sum(), (x, y_safe))

        # Ensure that we don't run extra view Node
        def f3(x, y):
            out = x.clone().view(-1)

            def hook(*args):
                # This should never be called!
                self.assertTrue(False)

            out.register_hook(hook)

            b = out + y
            out += y
            return out + b, b

        out, b = f3(x, y_safe)
        torch.autograd.grad(out.sum(), (b, y_safe))

    def test_grad_mode_class_decoration(self):
        # Decorating a class is deprecated and should not be used
        with self.assertWarnsRegex(FutureWarning, "Decorating classes is deprecated"):

            @torch.no_grad()
            class Foo:
                def __init__(self) -> None:
                    assert not torch.is_grad_enabled()

                def foo(self):
                    # Not applied to methods
                    assert torch.is_grad_enabled()

            # Show that we can actually construct the class
            foo = Foo()
            foo.foo()

        # Decorating functions or methods is fine though
        with warnings.catch_warnings(record=True) as w:

            @torch.no_grad()
            def foo():
                assert not torch.is_grad_enabled()

            foo()

            class Foo2:
                @torch.no_grad()
                def __init__(self) -> None:
                    assert not torch.is_grad_enabled()

                @torch.no_grad()
                def foo(self):
                    assert not torch.is_grad_enabled()

            foo2 = Foo2()
            foo2.foo()

        self.assertEqual(len(w), 0)

    def test_tensor_grad_warnings(self):
        dummy = torch.empty(1)

        with warnings.catch_warnings(record=True) as w:
            # Accessing .grad on leaf
            dummy.requires_grad_()
            foo = dummy.grad
            self.assertEqual(len(w), 0)

            # Accessing .grad on non-leaf
            dummy = dummy.clone()
            foo = dummy.grad
            self.assertEqual(len(w), 1)

            # Accessing .grad on non-leaf that retains gradients
            dummy.retain_grad()
            foo = dummy.grad
            self.assertEqual(len(w), 1)

    def _function_test(self, cls):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)
        result = cls.apply(x, 2, y)
        go = torch.ones((), requires_grad=True)
        result.sum().backward(go, create_graph=True)

        self.assertEqual(x.grad, y + torch.ones(5, 5))
        self.assertEqual(y.grad, x + torch.ones(5, 5) * 2)
        self.assertIsNotNone(x.grad.grad_fn)
        self.assertIsNotNone(y.grad.grad_fn)

        return x, y

    def test_function(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, tensor1, pyscalar, tensor2):
                ctx.pyscalar = pyscalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_tensors
                # NOTE: self is the test case here
                self.assertIsInstance(var1, torch.Tensor)
                self.assertIsInstance(var2, torch.Tensor)
                self.assertIsInstance(grad_output, torch.Tensor)
                return (
                    grad_output + grad_output * var2,
                    None,
                    grad_output * ctx.pyscalar + grad_output * var1,
                )
        x, y = self._function_test(MyFunction)

        x_grad_desc = graph_desc(x.grad.grad_fn)
        y_grad_desc = graph_desc(y.grad.grad_fn)
        self.assertExpected(x_grad_desc, "x_grad_desc")
        self.assertExpected(y_grad_desc, "y_grad_desc")

    def test_once_differentiable(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, tensor1, pyscalar, tensor2):
                ctx.pyscalar = pyscalar
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + pyscalar * tensor2 + tensor1 * tensor2

            @staticmethod
            @once_differentiable
            def backward(ctx, grad_output):
                self.assertFalse(torch.is_grad_enabled())
                t1, t2 = ctx.saved_tensors
                return (
                    grad_output + grad_output * t2,
                    None,
                    grad_output * ctx.pyscalar + grad_output * t1,
                )

        x, y = self._function_test(MyFunction)
        self.assertEqual(
            graph_desc(x.grad.grad_fn),
            "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
        )
        self.assertEqual(
            graph_desc(y.grad.grad_fn),
            "CopyBackwards(None, Error(AccumulateGrad(), None, AccumulateGrad()))",
        )

    def test_function_returns_input(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad * 2

        for shape in [(1,), ()]:
            v = torch.ones(shape, requires_grad=True)
            MyFunction.apply(v).backward()
            self.assertEqual(v.grad, torch.full(shape, 2.0))
            with torch.no_grad():
                v.grad.zero_()
            MyFunction.apply(v.clone()).backward()
            self.assertEqual(v.grad, torch.full(shape, 2.0))

    def test_function_returns_undefined_tensor(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, grad):
                return None

        # Test that undefined tensors returned from custom backward function
        # are propagated as undefined and not as a tensor full of zeros
        x = torch.ones(1, requires_grad=True)

        MyFunction.apply(x).backward()
        self.assertIsNone(x.grad)

        MyFunction.apply(x**2).backward()
        self.assertIsNone(x.grad)

        MyFunction.apply(x).sum().backward()
        self.assertIsNone(x.grad)

        self.assertIsNone(
            torch.autograd.grad(MyFunction.apply(x), x, allow_unused=True)[0]
        )

    def test_materialize_grads(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                self.assertEqual(grad, torch.zeros(1))
                return grad

        x = torch.ones(1, requires_grad=True)
        torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()

    def test_dont_materialize_grads(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.set_materialize_grads(False)
                return x

            @staticmethod
            def backward(ctx, grad):
                self.assertIsNone(grad)
                return grad

        x = torch.ones(1, requires_grad=True)
        torch._C._functions.UndefinedGrad()(MyFunction.apply(x)).backward()

    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
    def test_set_materialize_non_diff_grads(self):
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                out0 = x.clone()
                out1 = x.clone()
                ctx.mark_non_differentiable(out1)
                ctx._materialize_non_diff_grads = False
                return out0, out1

            @staticmethod
            def backward(ctx, g0, g1):
                self.assertIsNone(g1)
                return g0

        a = torch.tensor(1.0, requires_grad=True)
        out = Func.apply(a)[0]
        out.backward()

    def test_legacy_function_deprecation_exception(self):
        # Trigger exception
        class MyFunction(Function):
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

        # Check exception occurs
        with self.assertRaisesRegex(
            RuntimeError,
            "Legacy autograd function with non-static forward method is deprecated",
        ):
            MyFunction()(torch.randn(3, 4))

    class SimulateBackwardError(Function):
        @staticmethod
        def forward(ctx, input):
            return input.clone()

        @staticmethod
        @once_differentiable
        def backward(ctx, input):
            raise Exception("Simulate error on backward pass")  # noqa: TRY002

    def test_custom_function_exception(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)

        tmp = (t1 + t2) * (t1 + t2)
        t3 = TestAutograd.SimulateBackwardError.apply(tmp)
        with self.assertRaisesRegex(Exception, "Simulate error on backward pass"):
            t3.sum().backward()

    def test_custom_function_non_tensor_inputs_outputs(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, t1, t2, scale, t3):
                t4 = t1 + t2 * t3
                t5 = t1 * t2 + t3
                t4 *= scale
                t5 *= scale

                # Save scale
                ctx.scale = scale
                ctx.save_for_backward(t1, t2, t3)
                return scale, t4, None, True, t5, "bar", t1

            @staticmethod
            @once_differentiable
            def backward(ctx, *grads):
                # Verify grads
                self.assertEqual(7, len(grads))
                self.assertIsNone(grads[0])
                self.assertIsNone(grads[2])
                self.assertIsNone(grads[3])
                self.assertIsNone(grads[5])

                scale = ctx.scale
                var1, var2, var3 = ctx.saved_tensors
                return (
                    grads[1] * scale + grads[4] * var2 * scale + grads[6],
                    grads[1] * var3 * scale + grads[4] * var1 * scale,
                    None,
                    grads[1] * var2 * scale + grads[4] * scale,
                )

        t1 = torch.rand(10, dtype=torch.double, requires_grad=True)
        t2 = torch.rand(10, dtype=torch.double, requires_grad=True)
        t3 = torch.rand(10, dtype=torch.double)
        scale = random.randint(0, 10)
        res = MyFunction.apply(t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])

        # Validate running backward.
        torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()])
        self.assertIsNotNone(t1.grad)
        self.assertIsNotNone(t2.grad)
        self.assertIsNone(t3.grad)

        # Test gradcheck
        def foo(t1, t2, t3):
            res = MyFunction.apply(t1, t2, scale, t3)
            return res[1], res[4], res[6]

        gradcheck(foo, (t1, t2, t3))

    def test_custom_function_no_tensors(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, t1, t2, scale, t3):
                t4 = t1 + t2 * t3
                t5 = t1 * t2 + t3
                t4 *= scale
                t5 *= scale
                return scale, t4, None, True, t5, "bar", t1

            @staticmethod
            @once_differentiable
            def backward(ctx, *args):
                return (args[0], args[1], None, args[2])

        t1 = random.random()
        t2 = random.random()
        t3 = random.random()
        scale = random.randint(0, 10)
        res = MyFunction.apply(t1, t2, scale, t3)
        self.assertEqual(scale, res[0])
        self.assertEqual((t1 + t2 * t3) * scale, res[1])
        self.assertEqual(None, res[2])
        self.assertEqual(True, res[3])
        self.assertEqual((t1 * t2 + t3) * scale, res[4])
        self.assertEqual("bar", res[5])
        self.assertEqual(t1, res[6])
    def test_invalid_gradients(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, grad_output):
                return torch.randn(10, dtype=torch.float)

        with self.assertRaisesRegex(RuntimeError, "expected shape"):
            input = torch.randn(5, 5, dtype=torch.float, requires_grad=True)
            MyFunction.apply(input).sum().backward()

    def test_unrelated_inputs(self):
        # test to ensure grad(grad)check runs successfully even if there are
        # unrelated (but differentiable) inputs

        def my_function(x, y):
            return x * x

        x = torch.rand(10, dtype=torch.double, requires_grad=True)
        y = torch.rand(10, dtype=torch.double, requires_grad=True)

        gradcheck(my_function, (x, y))
        gradgradcheck(my_function, (x, y))

    def test_not_implemented_grad(self):
        a = torch.rand(2, requires_grad=True)
        # if grad for nextafter ends up being implemented, this should be changed
        y = torch.nextafter(a, a).sum()
        with self.assertRaisesRegex(
            NotImplementedError, "the derivative for .* is not implemented"
        ):
            y.backward()

    def test_not_implemented_fwad(self):
        x = torch.randn(3)
        v = torch.rand(3)

        with fwAD.dual_level():
            dual_x = fwAD.make_dual(x, v)

            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"

            with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                # if forward AD ends up being implemented for torch.igamma, choose a different op
                torch.igamma(dual_x, dual_x)

    def test_saved_tensor_hooks_extra_exit_during_bw_no_crash(self):
        # This usage of saved tensor hooks is not supported, but should not crash
        def unpack(x):
            ctx_1.__exit__()
            return x

        ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
        ctx_2 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x)

        for i in range(10):
            with ctx_2:
                ctx_1.__enter__()
                x = torch.randn(3, 3, requires_grad=True)
                x.sin().sum().backward()

        # Clean up
        for i in range(10):
            ctx_1.__exit__()

        # Validate there are no more hooks on the stack
        a = torch.tensor(1.0, requires_grad=True)
        y = a.exp()
        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)

    def test_saved_tensor_hooks_extra_enter_during_bw_no_leak(self):
        # This usage of saved tensor hooks is not supported, but should not leak
        def scope():
            def unpack(x):
                weak_ctx_1().__enter__()
                return x

            ctx_1 = torch.autograd.graph.saved_tensors_hooks(lambda x: x, unpack)
            weak_ctx_1 = weakref.ref(ctx_1)

            x = torch.randn(3, 3, requires_grad=True)
            with ctx_1:
                x.sin().sum().backward()
            return weakref.ref(unpack)

        with disable_gc():
            unpack_hook_ref = scope()
        self.assertIsNone(unpack_hook_ref())

    def test_will_engine_execute_node(self):
        counter = [0]

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2

            @staticmethod
            def backward(ctx, gO):
                return gO * 2

        def get_grad_fn(t):
            if t.requires_grad and t.grad_fn is None:
                return t.clone().grad_fn.next_functions[0][0]
            else:
                return t.grad_fn

        a = torch.randn(2, 3, 4, requires_grad=True)
        a2 = torch.randn(2, 3, 4, requires_grad=True)
        b = a * a2
        b2 = b.cos()
        c = MyFunction.apply(b)

        should_execute = list(map(get_grad_fn, (a, b, c)))
        should_not_execute = list(map(get_grad_fn, (a2, b2)))

        def fn(x):
            counter[0] += 1

            for g in should_execute:
                self.assertTrue(torch._C._will_engine_execute_node(g))

            for g in should_not_execute:
                self.assertFalse(torch._C._will_engine_execute_node(g))

        b.register_hook(fn)
        c.register_hook(fn)

        # .backward(inputs=) is OK
        out = c.sum()
        torch.autograd.backward(out, inputs=(a, b), retain_graph=True)
        self.assertEqual(counter[0], 2)

        # .backward() is OK
        should_execute = list(map(get_grad_fn, (a, a2, b, c)))
        should_not_execute = list(map(get_grad_fn, (b2,)))
        torch.autograd.backward(out, retain_graph=True)

        # .grad is NOT OK when leaf is passed (this is the current state, subject to change)
        with self.assertRaisesRegex(
            RuntimeError, "are currently running autograd.grad()"
        ):
            torch.autograd.grad(out, (a,))

        # .grad is OK when non-leaf is passed
        a = torch.randn(1, 2, 3, requires_grad=True) * 2
        b = a * 2

        def fn(x):
            # Check a non-leaf
            counter[0] += 1
            self.assertTrue(torch._C._will_engine_execute_node(b.grad_fn))

        b.register_hook(fn)
        counter[0] = 0
        torch.autograd.grad(b.sum(), (a,))
        self.assertEqual(counter[0], 1)

        # Verify other errors are raised
        with self.assertRaisesRegex(RuntimeError, "during the backward pass"):
            torch._C._will_engine_execute_node(out.grad_fn)

        with self.assertRaisesRegex(RuntimeError, "expects an grad_fn"):
            torch._C._will_engine_execute_node(out)

    def test_custom_function_vmap_defaults(self):
        class MySquare(Function):
            @staticmethod
            def forward(x):
                return x**2

            @staticmethod
            def setup_context(ctx, inputs, output):
                (x,) = inputs
                ctx.save_for_backward(x)

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return gO * 2 * x

        self.assertFalse(MySquare.generate_vmap_rule)
        self.assertTrue(hasattr(MySquare, "vmap"))

    def test_custom_function_setup_context_simple(self):
        class MySquare(Function):
            @staticmethod
            def forward(x):
                return x**2

            @staticmethod
            def setup_context(ctx, inputs, output):
                (x,) = inputs
                ctx.save_for_backward(x)

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                return gO * 2 * x

        x = torch.randn([], requires_grad=True)
        y = MySquare.apply(x)
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, 2 * x)

    def test_custom_function_setup_context_multi_output(self):
        # Multiple outputs with some non-Tensor outputs.
        class MySquare(Function):
            @staticmethod
            def forward(x):
                two_x = x.item() * 2
                return x**2, two_x

            @staticmethod
            def setup_context(ctx, inputs, output):
                (x,) = inputs
                _, two_x = output
                ctx.two_x = two_x

            @staticmethod
            @once_differentiable
            def backward(ctx, gO, _):
                return gO * ctx.two_x

        x = torch.randn([], requires_grad=True)
        y, _ = MySquare.apply(x)
        (gx,) = torch.autograd.grad(y, x)
        self.assertEqual(gx, 2 * x)

    def test_custom_function_setup_context_multi_input(self):
        class MyReshape(Function):
            @staticmethod
            def forward(x, shape, scale_forward, scale_backward):
                return x.reshape(shape) * scale_forward

            @staticmethod
            def setup_context(ctx, inputs, output):
                x, shape, scale_forward, scale_backward = inputs
                ctx.scale_backward = scale_backward
                ctx.x_shape = x.shape

            @staticmethod
            def backward(ctx, gO):
                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None

        class MyReshapeRef(Function):
            @staticmethod
            def forward(ctx, x, shape, scale_forward, scale_backward):
                ctx.scale_backward = scale_backward
                ctx.x_shape = x.shape
                return x.reshape(shape) * scale_forward

            @staticmethod
            def backward(ctx, gO):
                return gO.reshape(ctx.x_shape) * ctx.scale_backward, None, None, None

        def test(x, shape, scale_forward, scale_backward):
            y = MyReshape.apply(x, shape, scale_forward, scale_backward).sum()
            (gx,) = torch.autograd.grad(y, x)

            y_expected = MyReshapeRef.apply(
                x, shape, scale_forward, scale_backward
            ).sum()
            (gx_expected,) = torch.autograd.grad(y_expected, x)

            self.assertEqual(y_expected, y)
            self.assertEqual(gx_expected, gx)

        test(torch.randn(24, requires_grad=True), (3, 8), 7, 11)
        test(torch.randn(2, 3, 4, requires_grad=True), (6, 4), -1, 2)

    def test_multiple_insert_removal_caching(self):
        torch._C._set_cached_tensors_enabled(True)
        try:
            x = torch.rand([4])

            torch._C._add_cached_tensor(x)
            self.assertTrue(torch._C._is_cached_tensor(x))

            torch._C._add_cached_tensor(x)
            torch._C._remove_cached_tensor(x)

            self.assertFalse(torch._C._is_cached_tensor(x))
        finally:
            torch._C._set_cached_tensors_enabled(False)

    def test_accumulate_grad(self):
        grad_output = torch.ones(5, 5)

        def compute_grad(create_graph):
            x = torch.randn(5, 5, requires_grad=True)
            y = x + 2
            y.backward(grad_output, retain_graph=True)
            x_grad = x.grad
            x_grad_clone = x.grad.clone()
            y.backward(grad_output, create_graph=create_graph)
            return x_grad, x_grad_clone
        # Accumulate in-place when create_graph is False
        x_grad, x_grad_clone = compute_grad(create_graph=False)
        self.assertEqual(x_grad, x_grad_clone * 2)

        # Accumulate out-of-place when create_graph is True
        x_grad, x_grad_clone = compute_grad(create_graph=True)
        self.assertEqual(x_grad, x_grad_clone)

    def test_accumulate_grad_tensor_reference(self):
        def _test_grad_tensor(
            params_grad_tensor,
            backward_grad_tensor,
            should_preserve_reference,
            create_graph,
        ):
            params = torch.tensor([1.5, 1.5]).requires_grad_()
            params.grad = params_grad_tensor
            grad_saved = params.grad
            params.backward(backward_grad_tensor, create_graph=create_graph)
            self.assertEqual(
                id(grad_saved) == id(params.grad), should_preserve_reference
            )

        for create_graph in (False, True):
            # Accumulating a dense gradient into a sparse gradient will change the `params.grad` reference
            _test_grad_tensor(
                torch.sparse_coo_tensor(
                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
                ),
                torch.tensor([1.5, 1.5]),
                False,  # never accumulates in-place
                create_graph,
            )

            # Accumulating a dense gradient into a dense gradient will preserve the `params.grad` reference,
            # but only if create_graph=False.
            _test_grad_tensor(
                torch.tensor([1.5, 1.5]),
                torch.tensor([1.5, 1.5]),
                not create_graph,
                create_graph,
            )

            # Accumulating a sparse gradient into a sparse gradient will preserve the `params.grad` reference,
            # but only if create_graph=False.

    def test_accumulate_grad_tensor_reference(self):
        def _test_grad_tensor(
            params_grad_tensor,
            backward_grad_tensor,
            should_preserve_reference,
            create_graph,
        ):
            params = torch.tensor([1.5, 1.5]).requires_grad_()
            params.grad = params_grad_tensor
            grad_saved = params.grad
            params.backward(backward_grad_tensor, create_graph=create_graph)
            self.assertEqual(
                id(grad_saved) == id(params.grad), should_preserve_reference
            )

        for create_graph in (False, True):
            # Accumulating a dense gradient into a sparse gradient will change the `params.grad` reference
            _test_grad_tensor(
                torch.sparse_coo_tensor(
                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
                ),
                torch.tensor([1.5, 1.5]),
                False,  # never accumulates in-place
                create_graph,
            )

            # Accumulating a dense gradient into a dense gradient will preserve the `params.grad` reference,
            # but only if create_graph=False.
            _test_grad_tensor(
                torch.tensor([1.5, 1.5]),
                torch.tensor([1.5, 1.5]),
                not create_graph,
                create_graph,
            )

            # Accumulating a sparse gradient into a sparse gradient will preserve the `params.grad` reference,
            # but only if create_graph=False.
            _test_grad_tensor(
                torch.sparse_coo_tensor(
                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
                ),
                torch.sparse_coo_tensor(
                    torch.tensor([[1, 1]]).long(), torch.tensor([1.0, 1.0])
                ),
                not create_graph,
                create_graph,
            )

    def test_accumulate_grad_with_zero_numel_grad(self):
        a = torch.rand(4, 0, requires_grad=True)
        b = torch.rand(4, 1, requires_grad=True)
        c = a + b
        assert c.shape == (4, 0)
        c.sum().backward()

        self.assertEqual(b.grad, torch.zeros(4, 1))
        self.assertEqual(a.grad, torch.zeros(4, 0))

    def test_hessian_vector(self):
        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2, requires_grad=True)

        z = x**2 + y * x + y**2
        z.backward(torch.ones(2, 2), create_graph=True)

        with torch.no_grad():
            x_grad = 2 * x + y
            y_grad = x + 2 * y
        self.assertEqual(x.grad, x_grad)
        self.assertEqual(y.grad, y_grad)

        grad_sum = 2 * x.grad + y.grad
        grad_sum.backward(torch.ones(2, 2))
        x_hv = torch.ones(2, 2) * 5
        y_hv = torch.ones(2, 2) * 4
        self.assertEqual(x.grad, x_grad + x_hv)
        self.assertEqual(y.grad, y_grad + y_hv)

    def test_grad(self):
        x = torch.randn(2, 2, requires_grad=True)
        y = torch.randn(2, 2, requires_grad=True)
        z = x**2 + y * x + y**2
        z.backward(torch.ones(2, 2), create_graph=True)

        x_grad = 2 * x + y
        y_grad = x + 2 * y
        self.assertEqual(x.grad, x_grad)
        self.assertEqual(y.grad, y_grad)

        grad_sum = 2 * x.grad + y.grad
        x_hv = torch.autograd.grad(
            outputs=[grad_sum],
            grad_outputs=[torch.ones(2, 2)],
            inputs=[x],
            create_graph=True,
        )
        expected_x_hv = torch.ones(2, 2) * 5
        expected_y_hv = torch.ones(2, 2) * 4

        self.assertEqual(x_hv[0], expected_x_hv)
        self.assertEqual(x.grad, x_grad)
        self.assertEqual(y.grad, y_grad)

        # Test that grad_outputs and outputs have the same shape
        grad_out = torch.ones(2)
        try:
            torch.autograd.grad(
                outputs=[grad_sum],
                grad_outputs=[grad_out],
                inputs=[x],
                create_graph=True,
            )
            self.fail("expected a RuntimeError about mismatched shapes")
        except RuntimeError as error:
            self.assertEqual(
                str(error),
                "Mismatch in shape: grad_output[0] has a shape of "
                + str(grad_out.shape)
                + " and output[0] has a shape of "
                + str(grad_sum.shape)
                + ".",
            )
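
    # Editor's note (worked arithmetic added for clarity; it is not part of the
    # original comments): in test_hessian_vector and test_grad above,
    # z = x**2 + y*x + y**2 gives dz/dx = 2x + y and dz/dy = x + 2y, so
    # grad_sum = 2*(2x + y) + (x + 2y) = 5x + 4y. Backpropagating ones through
    # grad_sum therefore adds 5 to every entry of x.grad and 4 to every entry
    # of y.grad, which is where the expected hessian-vector constants 5 and 4
    # come from.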

    def test_grad_to_node(self):
        def check_matches(out, inp):
            ref = torch.autograd.grad(out.sum(), inp)

            edge = torch.autograd.graph.get_gradient_edge(inp)
            new = torch.autograd.grad(out.sum(), edge)
            self.assertEqual(ref, new)

        # We need to ensure that our main types of Node work (regular cpp Nodes,
        # AccumulateGrad Nodes and custom Function)
        x = torch.rand(2, requires_grad=True)
        out = x.clone()
        check_matches(out, x)

        x = x.clone()
        out = x.clone()
        check_matches(out, x)

        x = torch.autograd._functions.Resize.apply(x, (2,))
        out = x.clone()
        check_matches(out, x)

        x = torch.var_mean(x)[1]
        out = x.clone()
        check_matches(out, x)

    def test_grad_to_node_set(self):
        x = torch.rand(2, requires_grad=True)
        x_edge = torch.autograd.graph.get_gradient_edge(x)
        out = x.clone()

        with torch.no_grad():
            x.set_(torch.rand_like(x))

        with self.assertRaisesRegex(RuntimeError, "to not have been used in the graph"):
            torch.autograd.grad(out.sum(), x)

        # Works
        torch.autograd.grad(out.sum(), x_edge)

    def test_grad_to_node_inplace(self):
        x = torch.rand(2, requires_grad=True).clone()
        x_edge = torch.autograd.graph.get_gradient_edge(x)
        x *= 2

        g_old, g_new = torch.autograd.grad(x.sum(), (x_edge, x))
        self.assertEqual(g_old, 2 * torch.ones_like(x))
        self.assertEqual(g_new, torch.ones_like(x))

    def test_grad_to_node_multi(self):
        x = torch.rand(2, requires_grad=True).clone()
        y = torch.rand(2, requires_grad=True).clone()

        out = x + y

        ref = torch.autograd.grad(out.sum(), (x, y))

        inp_edges = (
            GradientEdge(x.grad_fn, x.output_nr),
            GradientEdge(y.grad_fn, y.output_nr),
        )
        new = torch.autograd.grad(out.sum(), inp_edges)

        self.assertEqual(ref, new)
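
    # Editor's note: illustrative sketch, not part of the original suite (the
    # _sketch_ prefix keeps it out of test collection). It restates the pattern
    # the test_grad_to_node* tests above exercise: a GradientEdge obtained from
    # torch.autograd.graph.get_gradient_edge() refers to the node that produced
    # a tensor, so gradients can be requested with respect to an intermediate
    # value without calling retain_grad() on it.
    def _sketch_grad_wrt_intermediate_via_edge(self):
        x = torch.rand(2, requires_grad=True)
        tmp = x.exp()  # non-leaf intermediate
        edge = torch.autograd.graph.get_gradient_edge(tmp)
        out = (tmp * 2).sum()

        (g_tmp,) = torch.autograd.grad(out, edge)
        self.assertEqual(g_tmp, 2 * torch.ones_like(tmp))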

    def test_grad_to_node_materialize(self):
        x = torch.rand(2, requires_grad=True).clone()
        edge_x = GradientEdge(x.grad_fn, x.output_nr)
        y = torch.rand(2, requires_grad=True).clone()
        edge_y = GradientEdge(y.grad_fn, y.output_nr)

        out = x.clone()

        # Works
        torch.autograd.grad(
            out.sum(), (edge_x, y), allow_unused=True, materialize_grads=True
        )
        torch.autograd.grad(
            out.sum(), (x, y), allow_unused=True, materialize_grads=True
        )
        torch.autograd.grad(out.sum(), (x, edge_y), allow_unused=True)

        with self.assertRaisesRegex(
            RuntimeError,
            "materialize_grads cannot be used when the given input is a GradientEdge",
        ):
            torch.autograd.grad(
                out.sum(), (x, edge_y), allow_unused=True, materialize_grads=True
            )

    def test_backward_to_node(self):
        x = torch.rand(2, requires_grad=True).clone()
        edge_x = GradientEdge(x.grad_fn, x.output_nr)
        y = torch.rand(2, requires_grad=True).clone()
        edge_y = GradientEdge(y.grad_fn, y.output_nr)

        out = x.clone()

        # All should work in this case
        torch.autograd.backward(out.sum(), inputs=(edge_x, y))
        torch.autograd.backward(out.sum(), inputs=(x, y))
        torch.autograd.backward(out.sum(), inputs=(x, edge_y))
        torch.autograd.backward(out.sum(), inputs=(edge_x, edge_y))

    def test_grad_fn_input_metadata(self):
        x = torch.rand(2, requires_grad=True, dtype=torch.float32)
        y = torch.rand(2, requires_grad=True, dtype=torch.float32)
        z = x * y
        z_metadata = z.grad_fn._input_metadata[0]
        self.assertEqual(z_metadata.shape, (2,))
        self.assertEqual(z_metadata.dtype, torch.float32)

        # Multiple outputs
        b = torch.rand(3, 3, requires_grad=True)
        var, _ = torch.var_mean(b, dim=0)

        metadata_0 = var.grad_fn._input_metadata[0]
        metadata_1 = var.grad_fn._input_metadata[1]
        self.assertEqual(metadata_0.shape, (3,))
        self.assertEqual(metadata_1.shape, (3,))

        # Preserves symints
        nt = torch.nested.nested_tensor(
            [torch.randn(3, 2), torch.randn(2, 2)],
            layout=torch.jagged,
            requires_grad=True,
        )
        nt_metadata = nt.clone().grad_fn._input_metadata[0]

        self.assertIsInstance(nt_metadata.shape[1], torch.SymInt)
        self.assertEqual(nt_metadata.shape, nt.shape)
        self.assertTrue(nt_metadata.is_nested_tensor)
        self.assertFalse(nt_metadata.is_cpp_nested_tensor)
        self.assertEqual(nt_metadata.dtype, nt.dtype)

        class Test(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

        x = torch.randn(3, 3, requires_grad=True)
        x = Test.apply(x)
        metadata = x.grad_fn._input_metadata[0]
        self.assertEqual(metadata.shape, (3, 3))

    def test_gradient_edge_output(self):
        x = torch.tensor([1.0, 2.0], requires_grad=True)

        def fn(x, reduce=True):
            tmp = x.sin().cos()
            if reduce:
                tmp = tmp.sum()
            out = tmp.exp().clone().sin().sum()
            tmp_edge = torch.autograd.graph.get_gradient_edge(tmp)
            return out, tmp_edge

        # Compute fn backward in two steps
        out, tmp_edge = fn(x)
        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))

        (x_grad,) = torch.autograd.grad(tmp_edge, (x,), grad_outputs=(tmp_grad,))

        # Compare with doing it in one go.
        out, _ = fn(x)
        (x_grad_ref,) = torch.autograd.grad(out, (x,))
        self.assertEqual(x_grad, x_grad_ref)

        # Incorrect case: grad_outputs not passed/implicitly None and the output
        # is not a scalar
        out, tmp_edge = fn(x, reduce=False)
        with self.assertRaisesRegex(
            RuntimeError, "grad can be implicitly created only for scalar output"
        ):
            torch.autograd.grad(tmp_edge, (x,))

        # OK case: grad_outputs is None and the output is a scalar
        out, tmp_edge = fn(x, reduce=True)
        torch.autograd.grad(tmp_edge, (x,))

        # Incorrect case: grad_outputs has the wrong size
        out, tmp_edge = fn(x)
        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
            torch.autograd.grad(
                tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
            )

        # Incorrect case: wrong dtype
        out, tmp_edge = fn(x)
        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
        with self.assertRaisesRegex(RuntimeError, "required to have the same dtype"):
            torch.autograd.grad(
                tmp_edge,
                (x,),
                grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
            )

    def test_grad_nonleaf(self):
        x_init = torch.randn(2, 2, requires_grad=True)
        x = x_init
        y = torch.randn(2, 2, requires_grad=True)
        grad_output = torch.ones(2, 2)

        def fn(x):
            return x**2 + y * x + y**2

        for _ in range(5):
            (grad_x,) = torch.autograd.grad(
                fn(x), x, grad_outputs=grad_output, create_graph=True
            )

            grad_x_expected = 2 * x + y
            self.assertIsNone(y.grad)
            self.assertIsNone(x.grad)
            self.assertEqual(grad_x, grad_x_expected)

            x = x + 0.05 * grad_x

        val_init = fn(x_init).sum()
        val_final = fn(x).sum()
        self.assertGreater(val_final, val_init)

        x.backward(grad_output)
        self.assertIsNotNone(y.grad)
        self.assertIsNotNone(x_init.grad)

    def test_grad_nonleaf_many_outputs(self):
        # This checks an edge case for function callbacks.
        # We want to capture two grads of a function, but can only
        # register a single callback.
        x = torch.randn(4, 2, requires_grad=True)
        a, b = x.chunk(2)

        def hook(*grads):
            hook_called[0] = True

        hook_called = [False]
        x.register_hook(hook)

        go = torch.randn(2, 2)
        grad_a, grad_b = torch.autograd.grad(
            (a + 2 * b), [a, b], grad_outputs=go, create_graph=True
        )

        self.assertEqual(grad_a, go)
        self.assertEqual(grad_b, go * 2)
        self.assertFalse(hook_called[0])
        self.assertIsNone(x.grad)

    def test_grad_nonleaf_register_hook(self):
        # This checks an edge case for register_hook.
        # We want to capture the grad of a nonleaf tensor,
        # but avoid segfaulting during the backward of other nonleaf tensors.
        x = torch.randn(5, requires_grad=True)
        x_list = x.unbind()

        x0 = x_list[0]
        hook_results = [None]

        def hook(grad):
            hook_results[0] = grad

        x0.register_hook(hook)

        x_list[0].backward()
        self.assertEqual(hook_results[0], torch.tensor(1.0))
        expected_grad = torch.tensor([1.0, 0, 0, 0, 0])
        self.assertEqual(x.grad, expected_grad)
        self.assertIsNone(x_list[0].grad)

        for i in range(1, 5, 1):
            x_list[i].backward()
            self.assertEqual(hook_results[0], None)
            expected_grad[i] = 1.0
            self.assertEqual(x.grad, expected_grad)
            self.assertIsNone(x_list[i].grad)

    def test_grad_materialize_grads(self):
        x = torch.tensor(0.5, requires_grad=True)
        a = torch.tensor(1.0, requires_grad=True)
        y = x * a
        dydx = torch.autograd.grad(y, x, create_graph=True)
        d2ydx2_none = torch.autograd.grad(dydx, x, create_graph=True, allow_unused=True)
        d2ydx2 = torch.autograd.grad(
            dydx, x, create_graph=True, allow_unused=True, materialize_grads=True
        )
        # `allow_unused` is set to True implicitly
        d3ydx3 = torch.autograd.grad(d2ydx2, x, materialize_grads=True)
        self.assertIsNone(d2ydx2_none[0])
        self.assertEqual(d2ydx2[0].item(), 0)
        self.assertEqual(d3ydx3[0].item(), 0)
        with self.assertRaisesRegex(
            ValueError, "Expected allow_unused to be True or not passed when"
        ):
            torch.autograd.grad(y, x, allow_unused=False, materialize_grads=True)
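
    # Editor's note: illustrative sketch, not part of the original suite (the
    # _sketch_ prefix keeps it out of test collection). It restates the
    # materialize_grads behaviour checked by test_grad_materialize_grads above:
    # with allow_unused=True an unused input yields None, and
    # materialize_grads=True turns that None into a zero tensor shaped like the
    # input.
    def _sketch_materialize_grads_returns_zeros(self):
        x = torch.tensor(2.0, requires_grad=True)
        unused = torch.tensor(3.0, requires_grad=True)
        out = x * x

        gx, gu = torch.autograd.grad(
            out, (x, unused), allow_unused=True, retain_graph=True
        )
        self.assertEqual(gx, torch.tensor(4.0))
        self.assertIsNone(gu)

        gx, gu = torch.autograd.grad(
            out, (x, unused), allow_unused=True, materialize_grads=True
        )
        self.assertEqual(gu, torch.zeros_like(unused))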

    def test_post_accumulate_grad_hook_on_non_leaf(self):
        def hook(tensor):
            tensor.sub_(1.0)

        leaf = torch.rand(3, requires_grad=True)
        non_leaf = 2.0 * leaf

        with self.assertRaisesRegex(
            RuntimeError,
            "post accumulate grad hooks cannot be registered on non-leaf tensors",
        ):
            non_leaf.register_post_accumulate_grad_hook(hook)

    def test_post_accumulate_grad_hook_multiple_hooks(self):
        def hook1(tensor):
            tensor.sub_(tensor.grad)

        def hook2(tensor):
            tensor.mul_(4.0)

        tensor = torch.rand(3, requires_grad=True)
        tensor_ref = tensor.clone().detach()
        tensor.register_post_accumulate_grad_hook(hook1)
        tensor.register_post_accumulate_grad_hook(hook2)
        sum = tensor.sum()
        sum.backward()
        # both hooks should be called, in order
        self.assertEqual(4.0 * (tensor_ref - 1.0), tensor)

    def test_post_accumulate_grad_hook_multiple_tensors(self):
        def hook(tensor):
            tensor.sub_(tensor.grad)

        tensor1 = torch.rand(3, requires_grad=True)
        tensor1_ref = tensor1.clone().detach()
        tensor2 = torch.rand(5, requires_grad=True)
        tensor2_ref = tensor2.clone().detach()
        tensor1.register_post_accumulate_grad_hook(hook)
        tensor2.register_post_accumulate_grad_hook(hook)
        tensor1.sum().backward()
        tensor2.sum().backward()
        # both tensors should have been modified
        self.assertEqual(tensor1_ref - 1.0, tensor1)
        self.assertEqual(tensor2_ref - 1.0, tensor2)
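
    # Editor's note: illustrative sketch, not part of the original suite (the
    # _sketch_ prefix keeps it out of test collection). A post-accumulate-grad
    # hook fires once a leaf's .grad has been fully accumulated; the e2e test
    # below uses that to run an optimizer step inside backward(). This variant
    # simply consumes and frees the gradient as soon as it is ready, a common
    # memory-saving use of the same hook.
    def _sketch_post_accumulate_grad_hook_frees_grad(self):
        seen = []

        def consume_and_free(param):
            seen.append(param.grad.clone())
            param.grad = None

        w = torch.rand(3, requires_grad=True)
        w.register_post_accumulate_grad_hook(consume_and_free)
        (w * 2).sum().backward()

        self.assertIsNone(w.grad)
        self.assertEqual(seen[0], 2 * torch.ones(3))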

    def test_post_accumulate_grad_hook_returns_not_None(self):
        def bad_hook(tensor):
            return tensor.grad

        tensor = torch.rand(2, 3, requires_grad=True)
        tensor.register_post_accumulate_grad_hook(bad_hook)
        # should error!
        with self.assertRaisesRegex(RuntimeError, "hooks should return None."):
            tensor.sum().backward()

    def test_post_accumulate_grad_hook_e2e(self):
        def setup_optim_in_bwd(model):
            optims = {}
            handles = []

            def optim_step_hook(param):
                optims[param].step()
                optims[param].zero_grad()

            for p in model.parameters():
                optims[p] = torch.optim.Adam([p])
                handles.append(p.register_post_accumulate_grad_hook(optim_step_hook))

            return handles

        model = torch.nn.Linear(3, 2)
        input = torch.rand(2, 3)
        handles = setup_optim_in_bwd(model)

        # make a copy for reference
        model_copy = deepcopy(model)
        optim_copy = torch.optim.Adam(model_copy.parameters())

        iters = 5

        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        params_copy = []  # freeze a copy of the params to compare later
        for p_reference, p in zip(model_copy.parameters(), model.parameters()):
            self.assertEqual(p_reference, p)
            params_copy.append(p_reference.clone().detach())

        # After removing the handles, the model should no longer update.
        for h in handles:
            h.remove()

        for _ in range(iters):
            loss = model(input).sum()
            loss.backward()

            loss_copy = model_copy(input).sum()
            loss_copy.backward()
            optim_copy.step()
            optim_copy.zero_grad()

        for p_static, p_reference, p in zip(
            params_copy, model_copy.parameters(), model.parameters()
        ):
            self.assertEqual(p_static, p)
            self.assertNotEqual(p_reference, p)

    def test_post_accumulate_grad_hook_gets_cleaned_up(self):
        def fun_stuff_with_hook():
            thing_to_put_in_hook = torch.rand(3)

            def hook(tensor):
                tensor.sub_(tensor.grad)
                tensor.add_(thing_to_put_in_hook)

            tensor = torch.rand(3, requires_grad=True)
            tensor.register_post_accumulate_grad_hook(hook)
            tensor.sum().backward()
            ref = weakref.ref(thing_to_put_in_hook)
            gc.collect()
            return tensor, ref

        with disable_gc():
            tensor, ref = fun_stuff_with_hook()
            self.assertIsNotNone(
                ref()
            )  # thing_to_put_in_hook should be kept alive by tensor

            del tensor
            gc.collect()
            self.assertIsNone(ref())  # thing_to_put_in_hook should be cleaned up

    def test_post_accumulate_grad_hook_ordering(self):
        tensor = torch.rand(3, requires_grad=True)

        def pre_hook(grad):
            return grad.sub(2.0)

        def acc_grad_node_pre_hook(grad_out):
            return (grad_out[0].div(5.0),)

        def post_acc_grad_hook(tensor):
            tensor.grad.add_(0.5)

        def acc_grad_node_post_hook(grad_in, grad_out):
            tensor.grad = grad_out[0].mul(10)

        acc_grad = tensor.view_as(tensor).grad_fn.next_functions[0][0]
        tensor.register_hook(pre_hook)
        acc_grad.register_prehook(acc_grad_node_pre_hook)
        tensor.register_post_accumulate_grad_hook(post_acc_grad_hook)
        acc_grad.register_hook(acc_grad_node_post_hook)
        tensor.sum().backward()

        # the hooks should run in the order of:
        # 1. tensor prehook
        # 2. acc_grad prehook
        # 3. tensor post acc_grad hook
        # 4. acc_grad posthook
        # so that would be ((1 - 2) / 5 + 0.5) * 10 = 3
        self.assertEqual(torch.tensor([3.0, 3.0, 3.0]), tensor.grad)

    def test_hook_with_no_name(self):
        # Create a hook that does not have a __name__ attribute
        class MyHookClass:
            def __call__(self, grad):
                return grad.clone()

        x = torch.randn(5, requires_grad=True).clone()
        x.register_hook(MyHookClass())
        x.sum().backward()
        # Should run fine
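
    # Editor's note (summary added for clarity; it is not part of the original
    # comments): for a leaf tensor the hooks exercised above fire in this order
    # during backward: (1) tensor hooks from Tensor.register_hook, (2) prehooks
    # on the AccumulateGrad node, (3) the gradient is accumulated into .grad,
    # (4) hooks from register_post_accumulate_grad_hook, and (5) posthooks on
    # the AccumulateGrad node, which is exactly the
    # ((1 - 2) / 5 + 0.5) * 10 = 3 arithmetic verified in
    # test_post_accumulate_grad_hook_ordering above.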

    def test_prehook_ordering(self):
        # Hooks registered to tensor are ordered before those
        # that are registered to grad_fn
        log = []

        def hook1(g):
            log.append(1)
            return g * 3

        def hook2(gs):
            log.append(2)
            return tuple(g * 2 for g in gs)

        a = torch.tensor(1.0, requires_grad=True)
        b = a.clone()

        b.grad_fn.register_prehook(hook2)
        b.register_hook(hook1)
        b.grad_fn.register_prehook(hook2)

        acc = b.grad_fn.next_functions[0][0]
        a.register_hook(hook1)
        acc.register_prehook(hook2)
        a.register_hook(hook1)

        b.sum().backward(retain_graph=True)
        self.assertEqual(log, [1, 2, 2, 1, 1, 2])

        # grad also runs hooks on accumulate grad nodes, even though
        # the accumulate grad nodes are not actually executed
        log = []
        torch.autograd.grad(b.sum(), inputs=(a,), retain_graph=True)
        self.assertEqual(log, [1, 2, 2, 1, 1])

        log = []
        b.sum().backward(inputs=(b,))
        self.assertEqual(log, [1, 2, 2])
        # retains_grad hooks would not observe modifications by all pre hooks
        # because they are executed after
        self.assertEqual(b.grad.item(), 3)

    def test_retains_grad_can_always_observe_tensor_prehook(self):
        def tensor_prehook(g):
            return g * 2

        a = torch.tensor(1.0, requires_grad=True)
        b = a.clone()
        b.register_hook(tensor_prehook)
        b.retain_grad()
        b.register_hook(tensor_prehook)

        b.clone().backward()
        self.assertEqual(b.grad.item(), 4)

        a = torch.tensor(1.0, requires_grad=True)
        b = a.clone()
        b.retain_grad()
        b.register_hook(tensor_prehook)

        b.clone().backward()
        self.assertEqual(b.grad.item(), 2)

    def test_accumulate_grad_posthooks_can_observe_tensor_prehook(self):
        # Post hooks on accumulate should be able to observe changes to
        # grad made by tensor prehooks
        a = torch.tensor(1.0, requires_grad=True)

        def tensor_prehook(g):
            return g * 2

        def posthook(gO, gI):
            self.assertTrue(torch.allclose(gI[0], a * 2))
            self.assertEqual(len(gO), 0)

        def prehook(gI):
            self.assertTrue(torch.allclose(gI[0], a * 2))
            self.assertEqual(len(gI), 1)

        b = a.clone()
        acc = b.grad_fn.next_functions[0][0]
        acc.register_hook(posthook)
        acc.register_prehook(prehook)
        a.register_hook(tensor_prehook)

        b.backward()

    def test_accumulate_grad_posthooks_should_not_execute(self):
        def tensor_prehook(g):
            raise RuntimeError

        def posthook(gO, gI):
            raise RuntimeError

        a = torch.tensor(1.0, requires_grad=True)
        a.register_hook(tensor_prehook)
        b = torch.tensor(1.0, requires_grad=True)
        c = a.clone()
        acc = c.grad_fn.next_functions[0][0]
        acc.register_hook(posthook)

        out = a + b + c
        out.sum().backward(inputs=[b])

    def test_hook_edge_case_when_called_with_grad(self):
        # grad executes the tensor hooks of the next node but not
        # grad_fn pre hooks or the post hooks
        a = torch.tensor(1.0, requires_grad=True)
        b = a * 2
        c = b * 2

        tensor_hook_count = [0]
        prehook_count = [0]
        posthook_count = [0]

        def reset_counts():
            nonlocal tensor_hook_count, prehook_count, posthook_count
            tensor_hook_count = [0]
            prehook_count = [0]
            posthook_count = [0]

        def tensor_prehook(g):
            tensor_hook_count[0] += 1

        def prehook(g):
            prehook_count[0] += 1

        def posthook(gI, gO):
            posthook_count[0] += 1

        a.register_hook(tensor_prehook)
        b.register_hook(tensor_prehook)
        acc = b.grad_fn.next_functions[0][0]
        acc.register_hook(posthook)
        acc.register_prehook(prehook)
        b.grad_fn.register_hook(posthook)
        b.grad_fn.register_prehook(prehook)

        torch.autograd.grad(c, inputs=(b), retain_graph=True)
        self.assertEqual(tensor_hook_count[0], 1)
        self.assertEqual(posthook_count[0], 0)
        self.assertEqual(prehook_count[0], 0)
        reset_counts()

        torch.autograd.grad(c, inputs=(a, b), retain_graph=True)
        self.assertEqual(tensor_hook_count[0], 2)
        self.assertEqual(posthook_count[0], 1)
        self.assertEqual(prehook_count[0], 1)
        reset_counts()

        c.backward(retain_graph=True)
        self.assertEqual(tensor_hook_count[0], 2)
        self.assertEqual(posthook_count[0], 2)
        self.assertEqual(prehook_count[0], 2)
        reset_counts()

        c.backward(inputs=(a, b), retain_graph=True)
        self.assertEqual(tensor_hook_count[0], 2)
        self.assertEqual(posthook_count[0], 2)
        self.assertEqual(prehook_count[0], 2)

    def test_sharded_grad(self):
        leaves = [torch.zeros(5, 5, requires_grad=True) for _ in range(10)]
        intermediates = [l * i + l * l for i, l in enumerate(leaves)]
        loss = sum(v * i for i, v in enumerate(intermediates)).sum()

        # define a helper for dividing intermediates into groups
        def group(l, group_size):
            return (l[i : i + group_size] for i in range(0, len(l), group_size))

        # Compute the d loss / d intermediates in chunks of shard_size
        shard_size = 2
        d_intermediates = [
            d_i
            for intermediates_batch in group(intermediates, shard_size)
            for d_i in torch.autograd.grad(loss, intermediates_batch)
        ]
        # Compute rest of backward pass
        torch.autograd.backward(intermediates, d_intermediates)

        for i, l in enumerate(leaves):
            self.assertEqual(l.grad, i * i * (1 + l))
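
    # Editor's note: illustrative sketch, not part of the original suite (the
    # _sketch_ prefix keeps it out of test collection). It boils the sharding
    # pattern from test_sharded_grad above down to a single intermediate:
    # compute d(loss)/d(intermediate) with torch.autograd.grad, then finish the
    # backward pass by feeding that gradient back in through
    # torch.autograd.backward.
    def _sketch_two_stage_backward(self):
        leaf = torch.zeros(3, requires_grad=True)
        intermediate = leaf * 2
        loss = (intermediate * 3).sum()

        # Stage 1: gradient with respect to the intermediate only.
        (d_intermediate,) = torch.autograd.grad(
            loss, intermediate, retain_graph=True
        )
        # Stage 2: resume from the intermediate down to the leaf.
        torch.autograd.backward(intermediate, d_intermediate)

        self.assertEqual(leaf.grad, 6 * torch.ones(3))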

    def test_backward_badcalls(self):
        x = torch.ones(1)
        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            x.backward()

    def test_grad_badcalls(self):
        x = torch.ones(1)
        y = x**2
        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            torch.autograd.grad(x, y)
        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            torch.autograd.grad(y, x)

        x = torch.ones(1, requires_grad=True)
        y = x**2
        torch.autograd.grad(y, x)  # this should succeed now

    def test_grad_empty_inputs(self):
        x = torch.tensor([1.0], requires_grad=True)
        with self.assertRaisesRegex(ValueError, "grad requires non-empty inputs."):
            torch.autograd.grad(2 * x, [], grad_outputs=torch.tensor([1.0]))

    def test_grad_fn_badcalls(self):
        error_regex = "expected .* arguments, got .* instead"
        x = torch.ones(1, requires_grad=True)
        y = x**2
        with self.assertRaisesRegex(TypeError, error_regex):
            y.grad_fn(x.detach(), x.detach())  # too many
        with self.assertRaisesRegex(TypeError, error_regex):
            y.grad_fn()  # too few

        y.grad_fn(x.detach())  # this should succeed

    def test_grad_unreachable(self):
        x = torch.ones(1, requires_grad=True)
        y = torch.ones(1, requires_grad=True)
        # Make sure x and y have grad accumulators allocated
        z = x * 2
        w = y * 2

        grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=True)
        self.assertEqual(grad_x, x * 2)
        self.assertIsNone(grad_y)

        # This is slightly different than the case above, because z doesn't even
        # have a grad accumulator allocated.
        z = torch.ones(1, requires_grad=True)
        grad_x, grad_z = torch.autograd.grad(x * 2, [x, z], allow_unused=True)
        self.assertEqual(grad_x, x * 2)
        self.assertIsNone(grad_z)

        # allow_unused=False, but grads contains None inside, should throw
        with self.assertRaisesRegex(RuntimeError, "Set allow_unused=True"):
            grad_x, grad_y = torch.autograd.grad(x * 2, [x, y], allow_unused=False)
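
    # Editor's note: illustrative sketch, not part of the original suite (the
    # _sketch_ prefix keeps it out of test collection). It restates the
    # allow_unused contract exercised by test_grad_unreachable above: an input
    # the outputs do not depend on yields None when allow_unused=True and
    # raises otherwise (combine with materialize_grads to get zeros instead of
    # None, as in test_grad_materialize_grads earlier in this file).
    def _sketch_allow_unused_contract(self):
        x = torch.ones(1, requires_grad=True)
        unrelated = torch.ones(1, requires_grad=True)
        out = x * 3

        gx, gu = torch.autograd.grad(
            out, [x, unrelated], allow_unused=True, retain_graph=True
        )
        self.assertEqual(gx, torch.tensor([3.0]))
        self.assertIsNone(gu)

        with self.assertRaisesRegex(RuntimeError, "Set allow_unused=True"):
            torch.autograd.grad(out, [x, unrelated], allow_unused=False)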
        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, x):
                self.fail("This node should not be executed!")

        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
        y = torch.randn(1, requires_grad=True)
        (gY,) = torch.autograd.grad(x, (y,), allow_unused=True)
        self.assertIsNone(gY)

        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
        y = torch.randn(1, requires_grad=True)
        z = torch.randn(1, requires_grad=True)
        (gY, gZ) = torch.autograd.grad(x + z, (y, z), allow_unused=True)
        self.assertIsNone(gY)
        self.assertIsNotNone(gZ)

        x = MyFunc.apply(torch.randn(1, requires_grad=True) * 2)
        y = torch.randn(1, requires_grad=True)
        torch.autograd.backward(x, inputs=(y,))  # allow_unused is implicitly True!
        self.assertIsNone(y.grad)

    def test_grad_batched_grad(self):
        x = torch.randn(2, 2, requires_grad=True)

        out = x.clone()  # Size([2, 2])
        batched_grad = (
            torch.arange(3).expand(2, 2, 3).transpose(0, 2)
        )  # Size([3, 2, 2])
        (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
        self.assertEqual(
            grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype)
        )

        # Detect shape mismatch
        grad_out = torch.ones(2, 2)
        with self.assertRaisesRegex(
            RuntimeError, "If `is_grads_batched=True`, we interpret the first"
        ):
            torch.autograd.grad(
                outputs=out,
                grad_outputs=(grad_out,),
                inputs=(x,),
                is_grads_batched=True,
            )

        # Scalar outputs
        out = x.sum()  # Size([])
        batched_grad = torch.arange(3)  # Size([3])
        (grad,) = torch.autograd.grad(out, (x,), (batched_grad,), is_grads_batched=True)
        self.assertEqual(
            grad, torch.arange(3).expand(2, 2, 3).transpose(0, 2).to(dtype=grad.dtype)
        )

        # We consider scalar and sized-1 to be a mismatch. This is consistent
        # with current non-batched behavior.
        grad_out = torch.ones(2).unsqueeze(1)
        with self.assertRaisesRegex(
            RuntimeError, "If `is_grads_batched=True`, we interpret the first"
        ):
            torch.autograd.grad(
                outputs=out,
                grad_outputs=(grad_out,),
                inputs=(x,),
                is_grads_batched=True,
            )
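
        # Note: is_grads_batched=True vmaps the backward pass over the leading
        # dimension of each grad_output; for the success cases above it is
        # roughly equivalent to the (slower) sketch below, which stacks one
        # grad() call per slice of batched_grad:
        #   grads = torch.stack(
        #       [torch.autograd.grad(out, (x,), (g,), retain_graph=True)[0]
        #        for g in batched_grad]
        #   )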

    def test_hooks(self):
        x = torch.ones(5, 5, requires_grad=True)
        y = torch.ones(5, 5) * 4
        y.requires_grad_(True)

        counter = [0]

        def bw_hook(inc, grad):
            self.assertIsInstance(grad, torch.Tensor)
            counter[0] += inc

        z = x**2 + x * 2 + x * y + y
        x.register_hook(lambda *args: bw_hook(0, *args))
        test = z.register_hook(lambda *args: bw_hook(1, *args))
        z.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(counter[0], 1)

        test2 = z.register_hook(lambda *args: bw_hook(2, *args))
        z.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(counter[0], 4)

        test2.remove()
        z.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(counter[0], 5)

        def bw_hook_modify(grad):
            return grad.mul(2)

        test.remove()
        z.register_hook(bw_hook_modify)
        with torch.no_grad():
            y.grad.zero_()
        z.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(y.grad, (x + 1) * 2)

        y.register_hook(bw_hook_modify)
        with torch.no_grad():
            y.grad.zero_()
        z.backward(torch.ones(5, 5))
        self.assertEqual(y.grad, (x + 1) * 4)
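
        # Tensor hooks run in registration order when the tensor's gradient is
        # computed, and a hook that returns a Tensor replaces the gradient that
        # is propagated further. Here bw_hook_modify doubles the gradient twice
        # (once on z, once on y), which is why the final value is (x + 1) * 4.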

    def _get_mul2(self, use_custom_function):
        if use_custom_function:

            class Mul2(Function):
                @staticmethod
                def forward(ctx, x):
                    return x * 2

                @staticmethod
                def backward(ctx, gO):
                    return gO * 2

            return Mul2.apply
        else:
            return lambda x: x * 2

    def test_grad_fn_prehooks(self):
        for use_custom_function in (True, False):
            mul2 = self._get_mul2(use_custom_function)

            a = torch.tensor([1.0], requires_grad=True)
            b = mul2(a)

            post_counter = [0]
            pre_counter = [0]

            def posthook(grad_input, grad_output):
                self.assertEqual(pre_counter[0], 3)
                self.assertTrue(torch.allclose(grad_output[0], torch.ones(1) * 8))
                self.assertTrue(torch.allclose(grad_input[0], torch.ones(1) * 16))
                post_counter[0] += 1
                return grad_input

            def prehook(grad_output):
                pre_counter[0] += 1
                return (grad_output[0] * 2,)

            # register posthook x 2
            b.grad_fn.register_hook(posthook)
            b.grad_fn.register_hook(posthook)
            # register prehook x 3
            b.grad_fn.register_prehook(prehook)
            b.grad_fn.register_prehook(lambda x: None)
            b.grad_fn.register_prehook(prehook)
            b.grad_fn.register_prehook(prehook)
            b.grad_fn.register_prehook(lambda x: x)
            b.grad_fn.register_prehook(lambda x: None)

            b.sum().backward()

            self.assertEqual(post_counter[0], 2)
            self.assertEqual(pre_counter[0], 3)

            # Return None
            a = torch.rand(3, 3, requires_grad=True)
            b = mul2(a)

            def prehook(grad_output):
                pre_counter[0] += 1
                return None

            b.grad_fn.register_prehook(prehook)
            b.sum().backward()
            self.assertEqual(pre_counter[0], 4)
            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
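
        # Node prehooks receive the tuple of grad_outputs before the node's
        # backward runs and may return a modified tuple (or None to leave it
        # unchanged); posthooks run afterwards with (grad_inputs, grad_outputs).
        # Hooks fire in registration order, which the counters above rely on.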

    def test_grad_fn_prehooks_multiple_outputs(self):
        # Compute gradients without hooks
        b = torch.rand(3, 3, requires_grad=True)
        var, mean = torch.var_mean(b, dim=0)
        (var + mean).sum().backward()

        # Compute gradients with hooks
        a = b.detach().requires_grad_()
        counter = [0]

        def prehook(grad_output):
            gvar, gmean = grad_output
            counter[0] += 1
            return (gvar * 2, gmean * 2)

        var, mean = torch.var_mean(a, dim=0)
        mean.grad_fn.register_prehook(prehook)
        (var + mean).sum().backward()

        self.assertEqual(counter[0], 1)
        # Compare
        self.assertTrue(torch.allclose(a.grad, b.grad * 2))

        # Test with custom Function
        class DoubleMul2(Function):
            @staticmethod
            def forward(ctx, x, a, y):
                ctx.a = a
                return a * x * 2, a, a * y * 2

            @staticmethod
            def backward(ctx, g1, _a, g2):
                return ctx.a * g1 * 2, None, ctx.a * g2 * 2

        counter = [0]

        def prehook(grad_output):
            g1, ga, g2 = grad_output
            self.assertIsNone(ga)
            counter[0] += 1
            return (g1 * 2, None, g2 * 2)

        a = torch.randn(3, 3, requires_grad=True)
        b = torch.randn(3, 3, requires_grad=True)
        k = 3
        c, _, d = DoubleMul2.apply(a, k, b)
        c.grad_fn.register_prehook(prehook)
        (c + d).sum().backward()

        self.assertEqual(counter[0], 1)
        self.assertTrue(torch.allclose(a.grad, torch.ones(1) * 4 * k))
        self.assertTrue(torch.allclose(b.grad, torch.ones(1) * 4 * k))

    def test_grad_fn_prehooks_remove_hooks(self):
        for use_custom_function in (True, False):
            mul2 = self._get_mul2(use_custom_function)

            # Simply remove hooks

            a = torch.rand(3, 3, requires_grad=True)
            b = mul2(a)
            counter = [0]

            def prehook(grad_output):
                counter[0] += 1
                return None

            handle = b.grad_fn.register_prehook(prehook)
            b.grad_fn.register_prehook(prehook)
            handle.remove()
            b.sum().backward()
            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
            self.assertEqual(counter[0], 1)

            # Remove hooks during backward
            a = torch.rand(3, 3, requires_grad=True)
            b = mul2(a)
            counter = [0]

            def prehook1(grad_output):
                handle2.remove()
                # Removing a hook that has already been removed is OK
                handle3.remove()
                return None

            def prehook2(grad_output):
                counter[0] += 1
                return None

            # Hooks that are registered first run first
            b.grad_fn.register_prehook(prehook1)
            handle2 = b.grad_fn.register_prehook(prehook2)
            handle3 = b.grad_fn.register_prehook(prehook2)
            handle3.remove()
            b.sum().backward()
            self.assertTrue(torch.allclose(a.grad, torch.ones(3, 3) * 2))
            self.assertEqual(counter[0], 1)

    def test_node_post_hook_registered_during_unpack_hook(self):
        """
        Test that post hooks registered during one of the node's
        unpack hooks are properly restricted and will run properly.
        """
        test_case = self

        class RegisterPostNodeHook(torch.autograd.graph.saved_tensors_hooks):
            def __init__(self) -> None:
                def pack_tensor(tensor: torch.Tensor) -> torch.Tensor:
                    return tensor

                def unpack_tensor(tensor: torch.Tensor) -> torch.Tensor:
                    node = torch._C._current_autograd_node()

                    def hook(outputs, inputs):
                        # Assert that inputs passed in are None
                        test_case.assertTrue(all(i is None for i in inputs))
                        halved_outputs = tuple(
                            o / 2.0 if o is not None else None for o in outputs
                        )
                        return halved_outputs

                    node.register_hook(hook)
                    return tensor

                super().__init__(pack_tensor, unpack_tensor)

        a = torch.rand(3, 3, requires_grad=True)

        def model():
            var, mean = torch.var_mean(a, dim=0)
            loss = (var + mean).sum()
            loss.backward()

        model()
        ref_grad = a.grad.clone()

        with RegisterPostNodeHook():
            model()

        # Verify that the post hook got called and the grad propagation worked
        self.assertEqual(ref_grad / 2.0 + ref_grad, a.grad)

    def test_hooks_cpp(self):
        # Tests hooks for autograd function implemented in C++
        bn = torch.nn.BatchNorm1d(5, affine=False)
        bn.double()
        bn.eval()

        counter = [0]

        def bw_hook(grad):
            counter[0] += 1
            return grad * 2

        x = torch.ones(5, 5, dtype=torch.double, requires_grad=True)
        z = bn(x)
        z.register_hook(bw_hook)
        z.sum().backward()

        self.assertEqual(counter[0], 1, msg="bw_hook not called")
        self.assertEqual(
            x.grad, torch.ones(5, 5, dtype=torch.double) * 2, atol=1e-5, rtol=0
        )

    def test_hook_none(self):
        # WARNING: this is a test for autograd internals.
        # You should never have to use such things in your code.
        class NoneGradientFunction(Function):
            @staticmethod
            def forward(ctx, x, y):
                assert ctx.needs_input_grad[0]
                assert not ctx.needs_input_grad[1]
                return x, y

            @staticmethod
            def backward(ctx, grad_x, grad_y):
                return grad_x, None

        was_called = [False]

        def hook(grad):
            self.assertIsNotNone(grad)
            was_called[0] = True

        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5)
        rx, ry = NoneGradientFunction.apply(x, y)
        rx.register_hook(hook)
        ry.register_hook(hook)
        sum(rx, ry).sum().backward()
        self.assertTrue(was_called[0])

    def test_retain_grad(self):
        input = torch.rand(1, 3, requires_grad=True)
        h1 = input * 3
        out = (h1 * h1).sum()

        # It should be possible to call retain_grad() multiple times
        h1.retain_grad()
        h1.retain_grad()

        # Gradient should be accumulated
        out.backward(retain_graph=True)
        self.assertEqual(h1 * 2, h1.grad)
        out.backward(retain_graph=True)
        self.assertEqual(h1 * 4, h1.grad)

        with torch.no_grad():
            input.grad.zero_()
        # It should be a no-op for leaves
        input.retain_grad()
        input.retain_grad()
        out.backward()
        self.assertEqual(input * 18, input.grad)

    # NB: See test/cpp/api/autograd.cpp for more tests on the interaction between
    # retains_grad and hooks in cpp
    def test_retain_grad_inplace(self):
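        # retain_grad() on a non-leaf installs a hook on the tensor's current
        # grad_fn; this test checks that in-place ops (which rebase the tensor
        # onto a new grad_fn) still populate .grad correctly.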
        a = torch.tensor([1.0], requires_grad=True).clone()
        a.retain_grad()
        a.mul_(2)
        a.sum().backward()
        self.assertEqual(a.grad, torch.tensor([1.0]))

        a = torch.tensor([1.0], requires_grad=True).clone()
        a.retain_grad()
        # Inplace multiple times is OK
        a.mul_(2)
        a.mul_(2)
        a.sum().backward()
        self.assertEqual(a.grad, torch.tensor([1.0]))

        # When in-place over view is done, the retains_grad hooks should be
        # moved from base's original grad_fn to the copyslices node.
        x = torch.tensor([1.0], requires_grad=True).clone()
        x.retain_grad()
        x_view = x[:]
        x_view *= 2
        x *= 2
        x.sum().backward()
        # The grad is 1, not 4, because we are computing grad wrt the latest
        # version of x.
        self.assertEqual(x.grad, torch.tensor([1.0]))

        # If the base did not originally require grad, there should be no hook
        # to move. Make sure this case runs without error.
        x = torch.zeros(4)
        y = x.view(2, 2)
        y.add_(torch.randn(2, 2, requires_grad=True))

    def test_retains_grad_inplace_multiple_outputs(self):
        class DoubleMul(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 3

            @staticmethod
            def backward(ctx, g1, g2):
                return g1 * 2 + g2 * 3

        var_mean = partial(torch.var_mean, dim=0)

        for fn in (DoubleMul.apply, var_mean):
            b = torch.rand(3, 3, requires_grad=True)
            var, mean = fn(b)
            var.retain_grad()
            mean.retain_grad()
            # node has two retains_grad hooks
            var.mul_(2)
            # the retain_grad hook that the multi-output node refers to should now be a nullptr
            (var + mean).sum().backward()
            gvar = var.grad
            gmean = mean.grad

            a = b.detach().requires_grad_(True)
            var, mean = fn(a)
            var.mul_(2)
            out = (var + mean).sum()
            gvar_expected, gmean_expected = torch.autograd.grad(out, inputs=(var, mean))
            self.assertTrue(torch.allclose(gvar, gvar_expected))
            self.assertTrue(torch.allclose(gmean, gmean_expected))

    def test_retain_grad_inplace_over_view(self):
        base = torch.tensor([1.0], requires_grad=True).clone()
        view = base[:]
        view2 = base[:]
        view.retain_grad()
        view2.retain_grad()
        view.mul_(2)
        (view + view2).sum().backward()

        # The old grad_fn, slice, wouldn't be part of the graph during backward
        # so if the retains grad were not properly updated to the new grad_fn,
        # the grad would still be None
        self.assertEqual(view.grad, view2.grad)
        self.assertEqual(view.grad, torch.tensor([1.0]))
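
    # In the tests below, tensor hooks are attached to the tensor's grad_fn at
    # registration time. After an in-place op the tensor gets a new grad_fn, so
    # a hook registered before the op observes the gradient with respect to the
    # old version of the tensor (e.g. fn1 below sees 4.0: x2 from mul_, x2 from
    # fn2's returned value), while hooks registered afterwards see the gradient
    # of the current version.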
    def test_tensor_hooks_inplace(self):
        # Check that the second hook gets registered to the new version of tensor
        count1 = [0]
        count2 = [0]

        def fn1(grad):
            count1[0] += 1
            # x2 from mul, x2 from fn2
            self.assertEqual(grad, torch.tensor([4.0]))
            return grad * 2

        def fn2(grad):
            count2[0] += 1
            self.assertEqual(grad, torch.tensor([1.0]))
            return grad * 2

        a = torch.tensor([1.0], requires_grad=True)
        b = a.clone()
        b.register_hook(fn1)
        b.mul_(2)
        b.register_hook(fn2)
        b.sum().backward()
        self.assertEqual(count1[0], 1)
        self.assertEqual(count2[0], 1)
        self.assertEqual(a.grad, torch.tensor([8.0]))

        count3 = [0]

        def fn3(grad):
            count3[0] += 1
            self.assertEqual(grad, torch.tensor([4.0]))
            return grad * 2

        a = torch.tensor([1.0], requires_grad=True)
        b = a.clone()
        b.register_hook(fn3)
        # Inplace multiple times is OK
        b.mul_(2)
        b.mul_(2)
        b.sum().backward()
        self.assertEqual(count3[0], 1)
        self.assertEqual(a.grad, torch.tensor([8.0]))

    def test_tensor_hooks_inplace_multiple_outputs(self):
        class DoubleMul(Function):
            @staticmethod
            def forward(ctx, x):
                return x * 2, x * 3

            @staticmethod
            def backward(ctx, g1, g2):
                return g1 * 2 + g2 * 3

        var_mean = partial(torch.var_mean, dim=0)

        for fn in (DoubleMul.apply, var_mean):
            counts = [0, 0, 0]

            def fn0(grad):
                counts[0] += 1
                self.assertEqual(grad, torch.ones_like(out1) * 2)

            def fn1(grad):
                counts[1] += 1
                self.assertEqual(grad, torch.ones_like(out1) * 3)

            def fn2(grad):
                counts[2] += 1
                self.assertEqual(grad, torch.ones_like(out1))

            b = torch.rand(3, 3, requires_grad=True)
            out1, out2 = fn(b)
            out1.register_hook(fn0)
            out2.register_hook(fn1)
            # node refers to two hook dicts
            # out1 no longer points to its old hook dict
            out1.mul_(2)
            # fn2 is registered to out1's new hook dict
            out1.register_hook(fn2)
            (out1 + out2 * 3).sum().backward()
            self.assertEqual(counts, [1, 1, 1])

    def test_tensor_hooks_inplace_over_view(self):
        # There might be a better UX here, but this is the way it is now
        count = [0]

        def fn0(grad):
            self.fail()

        def fn1(grad):
            self.fail()

        def fn2(grad):
            count[0] += 1
            self.assertEqual(grad, torch.tensor([1.0]))

        base = torch.tensor([1.0], requires_grad=True).clone()
        view = base[:]
        view2 = base[:]
        view.register_hook(fn0)
        view2.register_hook(fn1)
        view.mul_(2)
        # We need to explicitly trigger an update to view to update its grad_fn
        view2.grad_fn
        view2.register_hook(fn2)
        (view + view2).sum().backward()
        # The hooks originally registered to view are not fired; one must explicitly
        # trigger an update to the view's grad_fn and then register a new hook
        self.assertEqual(count[0], 1)

    def test_retain_grad_cycle(self):
        x = torch.ones(5, 5, requires_grad=True)

        def run_test():
            y = x * 2
            y.retain_grad()

            return y / 2, torch._C._WeakTensorRef(y)

        z, ref = run_test()
        self.assertTrue(ref.expired())
        z.sum().backward()

    def test_backward(self):
        v = torch.randn(5, 5, requires_grad=True)
        x = torch.randn(5, 5, requires_grad=True)
        y = (torch.rand(5, 5) + 0.1).requires_grad_(True)
        z = torch.randn(5, 5, requires_grad=True)
        grad_output = torch.randn(5, 5)

        v.backward(grad_output)
        self.assertEqual(v.grad, grad_output)

        a = x + (y * z) + 4 * z**2 * x / y
        a.backward(grad_output)
        x_grad = 4 * z.pow(2) / y + 1
        y_grad = z - 4 * x * z.pow(2) / y.pow(2)
        z_grad = 8 * x * z / y + y
        self.assertEqual(x.grad, x_grad * grad_output)
        self.assertEqual(y.grad, y_grad * grad_output)
        self.assertEqual(z.grad, z_grad * grad_output)

    def test_to_sparse_backward(self):
        to_attr_names = (
            "to_dense",
            "to_sparse",
            "to_sparse_csr",
            "to_sparse_csc",
            "to_sparse_bsr",
            "to_sparse_bsc",
        )
        to_params = ((), (), (), (), (2,), (2,))
        to_attr_names_params = dict(zip(to_attr_names, to_params))

        def check_inversion_possible(
            t, layout1, layout1_params, layout2, layout2_params
        ):
            l = (layout1, layout2)
            p = (layout1_params, layout2_params)
            for l1, l2, p1, p2 in ((*l, *p), (*l[::-1], *p[::-1])):
                try:
                    to_l1 = getattr(t, l1)(*p1)
                    to_l2 = getattr(to_l1, l2)(*p2)
                except RuntimeError:
                    return False

            return True

        self_strided = torch.rand(4, 4, dtype=torch.double) + 1
        grad_strided = torch.rand(4, 4, dtype=torch.double) + 1

        for from_to_attr in to_attr_names:
            from_params = to_attr_names_params[from_to_attr]
            self_from = getattr(self_strided, from_to_attr)(
                *from_params
            ).requires_grad_(True)

            for to_to_attr in to_attr_names[1:]:
                to_params = to_attr_names_params[to_to_attr]

                if check_inversion_possible(
                    self_strided, from_to_attr, from_params, to_to_attr, to_params
                ):
                    self_to = getattr(self_from, to_to_attr)(*to_params)
                    grad_to = getattr(grad_strided, to_to_attr)(*to_params)

                    # No gradcheck support for BSR/BSC, so the grads are checked explicitly
                    grad_res = torch.autograd.grad(self_to, self_from, grad_to)[0]

                    self.assertEqual(grad_res.layout, self_from.layout)
                    self.assertEqual(grad_res.to_dense(), grad_strided)

    def test_sparse_mm_backward(self):
        size = (3, 3)

        mm_test_cases = product(*(([False, True],) * 4))

        for a_req_grad, a_is_sparse, b_req_grad, b_is_sparse in mm_test_cases:
            # We should only be testing cases with sparse inputs, and at least one
            # input needs to require grad so we can call a backward pass
            if not ((a_is_sparse or b_is_sparse) and (a_req_grad or b_req_grad)):
                continue
            a = torch.randn(size)
            if a_is_sparse:
                # detaching as `a` needs to be a leaf
                a = a.to_sparse().detach()
            b = torch.randn(size)
            if b_is_sparse:
                # detaching as `b` needs to be a leaf
                b = b.to_sparse().detach()

            a = a.requires_grad_(a_req_grad)
            b = b.requires_grad_(b_req_grad)

            r = a.mm(b)
            s = r.sum().backward()
            a_grad = None if a.grad is None else a.grad.clone().detach()
            b_grad = None if b.grad is None else b.grad.clone().detach()

            # Redo with only dense tensors
            a = (
                (a.to_dense() if a.is_sparse else a)
                .clone()
                .detach()
                .requires_grad_(a_req_grad)
            )
            b = (
                (b.to_dense() if b.is_sparse else b)
                .clone()
                .detach()
                .requires_grad_(b_req_grad)
            )

            r = a.mm(b)
            r.sum().backward()

            self.assertEqual(a_grad, a.grad)
            self.assertEqual(b_grad, b.grad)

    def test_multi_backward(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)

        q = torch.randn(5, 5, requires_grad=True)

        a = torch.randn(5, 5, requires_grad=True)
        b = torch.randn(5, 5, requires_grad=True)

        q2 = q * 2
        z = x + y + q2
        c = a * b + q2
        grad_z = torch.randn(5, 5)
        grad_c = torch.randn(5, 5)
        torch.autograd.backward([z, c], [grad_z, grad_c])

        self.assertEqual(x.grad, grad_z)
        self.assertEqual(y.grad, grad_z)
        self.assertEqual(a.grad, grad_c * b)
        self.assertEqual(b.grad, grad_c * a)
        self.assertEqual(q.grad, (grad_c + grad_z) * 2)

    def test_multi_backward_no_grad(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=False)

        z = x + y
        q = y * 2

        # NB: we currently raise an exception if any arguments to backwards
        # have requires_grad=False and don't have a grad_fn. We may want to
        # relax that check to a warning.
        def call_backwards():
            torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])

        self.assertRaises(RuntimeError, call_backwards)

    def test_backward_with_inputs(self):
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

        def fn():
            return x**2 + y * x + y**2

        gradient = torch.ones(2, 2)
        x_grad_expected = 2 * x + y
        y_grad_expected = x + 2 * y

        @torch.no_grad()
        def reset_grad():
            x.grad.zero_()
            y.grad.zero_()

        torch.autograd.backward(fn(), gradient, inputs=[x, y])
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(y.grad, y_grad_expected)

        reset_grad()
        torch.autograd.backward(fn(), gradient, inputs=[x])
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(y.grad, torch.zeros(2, 2), exact_dtype=False)

        reset_grad()
        torch.autograd.backward(fn(), gradient, inputs=[y])
        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)

        reset_grad()
        torch.autograd.backward(fn(), gradient, inputs=y)
        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, torch.zeros(2, 2), exact_dtype=False)

        reset_grad()
        self.assertRaisesRegex(
            RuntimeError,
            "cannot be empty",
            lambda: torch.autograd.backward(fn(), gradient, inputs=[]),
        )

    def test_backward_with_nonleaf_inputs(self):
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        x_nonleaf = x * 1
        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

        out = x_nonleaf**2 + y * x_nonleaf + y**2

        out.backward(
            torch.ones(2, 2, dtype=torch.double),
            create_graph=True,
            inputs=[x, y, x_nonleaf],
        )
        x_grad_expected = 2 * x + y
        y_grad_expected = x + 2 * y
        x_non_leaf_expected = 2 * x_nonleaf + y

        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)

        # backward doesn't have an allow_unused flag, so the behavior of backward
        # when a variable is not part of the graph is as if allow_unused were True:
        # z.grad will simply be None.

    def test_backward_with_nonleaf_inputs(self):
        x = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        x_nonleaf = x * 1
        y = torch.randn(2, 2, dtype=torch.double, requires_grad=True)
        z = torch.randn(2, 2, dtype=torch.double, requires_grad=True)

        out = x_nonleaf**2 + y * x_nonleaf + y**2

        out.backward(
            torch.ones(2, 2, dtype=torch.double),
            create_graph=True,
            inputs=[x, y, x_nonleaf],
        )
        x_grad_expected = 2 * x + y
        y_grad_expected = x + 2 * y
        x_non_leaf_expected = 2 * x_nonleaf + y

        self.assertEqual(y.grad, y_grad_expected)
        self.assertEqual(x.grad, x_grad_expected)
        self.assertEqual(x_nonleaf.grad, x_non_leaf_expected)

        # backward doesn't have an allow_unused flag, so the behavior of backward
        # when a variable is not part of the graph is as if allow_unused were True:
        # z.grad will simply be None.
        out.backward(
            torch.ones(2, 2, dtype=torch.double), create_graph=True, inputs=[z]
        )
        self.assertIsNone(z.grad)

    def test_dependent_backward(self):
        x = torch.randn(10, requires_grad=True)
        y = x**2
        z = y**3

        go_y = torch.randn(10)
        go_z = torch.randn(10)
        torch.autograd.backward([y, z], [go_y, go_z])

        xd = x
        self.assertEqual(x.grad, 2 * xd * go_y + 6 * xd.pow(5) * go_z)

    def test_save_output_nr(self):
        x = torch.randn(10, requires_grad=True)

        class MultiOutputFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x[:5], x[5:]

            @staticmethod
            def backward(ctx, *grad):
                return torch.cat(grad)

        a, b = MultiOutputFn.apply(x)
        self.assertEqual(b.output_nr, 1)

        class TestFn(Function):
            @staticmethod
            def forward(ctx, b):
                ctx.save_for_backward(b)
                return b * 2

            @staticmethod
            def backward(ctx, grad_b):
                (b,) = ctx.saved_tensors
                self.assertEqual(b.output_nr, 1)

        TestFn.apply(b).sum().backward()

    def test_first_grad_fn_access_in_no_grad_mode(self):
        a = torch.tensor([1 + 1j], requires_grad=True).clone()
        v = a.real
        a.add_(1)
        with torch.autograd.grad_mode.no_grad():
            v.grad_fn
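
    # Illustrative sketch, not part of the original suite: a worked instance
    # of the rule exercised by test_dependent_backward above. With two
    # dependent roots, backward sums both contributions into the leaf,
    # dL/dx = go_y * dy/dx + go_z * dz/dx. The helper name is hypothetical and
    # the leading underscore keeps unittest from collecting it as a test.
    def _sketch_dependent_roots_accumulate(self):
        x = torch.tensor([3.0], requires_grad=True)
        y = x * 2  # dy/dx = 2
        z = y * y  # dz/dx = 2 * y * dy/dx = 8 * x
        torch.autograd.backward([y, z], [torch.ones(1), torch.ones(1)])
        self.assertEqual(x.grad, 2 + 8 * x.detach())  # 2 + 24 = 26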

    @skipIfTorchDynamo("too slow")
    def test_free_deep_graph(self):
        def scope():
            depth = 150000
            x = torch.randn(1, requires_grad=True)
            y = x.clone()

            # build a "chain" computation graph
            for _ in range(depth):
                y = y + y * 0.000001

            # graph deletion occurs when the above locals go out of scope.
            # In this case `del y` will trigger it but it's easier to leave
            # it to Python to delete the locals.

        # Should not stack overflow
        scope()

    @skipIfTorchDynamo("too slow")
    def test_free_deep_graph_complicated(self):
        def scope():
            depth = 100000
            randchoice = torch.randint(2, [depth, 2])
            x = torch.randn(1, requires_grad=True)
            y = x.clone()

            # Hold the two previous values
            prev_values = [None, None]

            # Build a "chain with skip connections" graph
            for _ in range(depth):
                prev_tensors = [
                    tensor for tensor in prev_values[:-1] if tensor is not None
                ]
                prev_values.append(y)
                prev_values.pop(0)

                # Definitely pick one tensor to add
                y += y * 0.000001

                # Possibly add other tensors
                nprev = len(prev_tensors)
                if nprev == 2:
                    y += randchoice[depth].mul(torch.cat(prev_tensors)).sum()

            # graph deletion occurs when the above locals go out of scope.

        # Should not stack overflow
        scope()

    @skipIfTorchDynamo("too slow")
    def test_free_deep_graph_pyfunction(self):
        class MyOp(Function):
            @staticmethod
            def forward(ctx, tensor1, tensor2):
                return tensor1 + tensor2

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, grad_output

        def scope():
            depth = 150000
            x = torch.randn(1, requires_grad=True)
            y = x.clone()

            # build deeply nested computation graph
            for _ in range(depth):
                y = MyOp.apply(y, y)

            # graph deletion occurs when the above locals go out of scope.

        # Should not stack overflow
        scope()
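
    # Illustrative sketch, not part of the original suite: a small-scale
    # version of the custom Function chain built in
    # test_free_deep_graph_pyfunction above. Applying an "add both inputs"
    # Function to the same tensor doubles the gradient at every step, so a
    # chain of `depth` applications yields d(out)/dx == 2 ** depth. The helper
    # name is hypothetical and the leading underscore keeps unittest from
    # collecting it as a test.
    def _sketch_pyfunction_chain_grad(self):
        class Add(Function):
            @staticmethod
            def forward(ctx, a, b):
                return a + b

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, grad_output

        x = torch.ones(1, requires_grad=True)
        y = x.clone()
        depth = 10
        for _ in range(depth):
            y = Add.apply(y, y)
        y.backward()
        self.assertEqual(x.grad, torch.full((1,), 2.0**depth))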

    def test_no_unnecessary_save(self):
        # If we kept x in the derivative Function of x * ft we would
        # get an error in the backward that would complain that we've
        # modified x, which was needed for gradient computation.
        # Since we should elide unnecessary saves, this test should pass.
        mu = torch.ones(1, requires_grad=True)
        x = torch.empty(1)
        loss = 0
        for i in range(3):
            x.detach_()
            x.copy_(mu + i)
            ft = torch.tensor([float(i)])
            multiplied = x * ft
            s = multiplied.sum()
            loss += s
        loss.backward()

    def test_no_grad(self):
        x = torch.ones(5, 5, requires_grad=True)
        y = torch.ones(5, 5) * 4
        with torch.no_grad():
            w = x + y

        def adder(x, y):
            return x + y

        adders = [torch.no_grad()(adder), torch.no_grad(adder)]

        for adder in adders:
            z = adder(x, y)

            self.assertFalse(w.requires_grad)
            self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))
            self.assertIsNone(w.grad_fn)
            self.assertFalse(z.requires_grad)
            self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
            self.assertIsNone(z.grad_fn)

        # test nested decorator and with-statement on no_grad
        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            w = adder(x, y)
            self.assertFalse(torch.is_grad_enabled())

    def test_enable_grad_decorator_no_paren(self):
        x = torch.ones(1, requires_grad=True)

        @torch.enable_grad
        def doubler(x):
            return x * 2

        with torch.no_grad():
            z = doubler(x)
        self.assertTrue(z.requires_grad)
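
    # Illustrative sketch, not part of the original suite: it condenses the
    # grad-mode behavior covered by test_no_grad and
    # test_enable_grad_decorator_no_paren above. torch.no_grad() works both as
    # a context manager and as a decorator, and torch.enable_grad() re-enables
    # gradient tracking inside a no_grad region. The helper name is
    # hypothetical and the leading underscore keeps unittest from collecting
    # it as a test.
    def _sketch_grad_mode_decorators(self):
        @torch.no_grad()
        def multiply(a, b):
            return a * b

        x = torch.ones(3, requires_grad=True)
        out = multiply(x, x)
        self.assertFalse(out.requires_grad)  # computed with grad mode off

        with torch.no_grad():
            with torch.enable_grad():
                self.assertTrue(torch.is_grad_enabled())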

    def test_set_grad_generator_functions(self):
        @torch.no_grad()
        def gen_no_grad():
            for i in range(10):
                self.assertEqual(torch.is_grad_enabled(), False)
                yield i

        with torch.enable_grad():
            for _ in gen_no_grad():
                self.assertEqual(torch.is_grad_enabled(), True)

        @torch.enable_grad()
        def gen_enable_grad():
            for i in range(10):
                self.assertEqual(torch.is_grad_enabled(), True)
                yield i

        with torch.no_grad():
            for _ in gen_enable_grad():
                self.assertEqual(torch.is_grad_enabled(), False)

    def test_set_grad_generator_functions_recursive(self):
        # enable_grad_decorator_recursive and no_grad_decorator_recursive call each other
        # recursively, to ensure that the decorators preserve the caller's setting
        @torch.enable_grad()
        def enable_grad_decorator_recursive(depth):
            self.assertTrue(torch.is_grad_enabled())
            if depth > 0:
                no_grad_decorator_recursive(depth - 1)
                self.assertTrue(torch.is_grad_enabled())

        @torch.no_grad()
        def no_grad_decorator_recursive(depth):
            self.assertFalse(torch.is_grad_enabled())
            if depth > 0:
                enable_grad_decorator_recursive(depth - 1)
                self.assertFalse(torch.is_grad_enabled())

        # enable_grad_context_manager_recursive and no_grad_context_manager_recursive call
        # each other recursively, to ensure that the decorators preserve the caller's setting
        def enable_grad_context_manager_recursive(depth):
            with torch.enable_grad():
                self.assertTrue(torch.is_grad_enabled())
                if depth > 0:
                    no_grad_context_manager_recursive(depth - 1)
                    self.assertTrue(torch.is_grad_enabled())

        def no_grad_context_manager_recursive(depth):
            with torch.no_grad():
                self.assertFalse(torch.is_grad_enabled())
                if depth > 0:
                    enable_grad_context_manager_recursive(depth - 1)
                    self.assertFalse(torch.is_grad_enabled())

        with torch.enable_grad():
            self.assertTrue(torch.is_grad_enabled())
            enable_grad_decorator_recursive(10)
            self.assertTrue(torch.is_grad_enabled())
            enable_grad_context_manager_recursive(10)
            self.assertTrue(torch.is_grad_enabled())

        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            enable_grad_decorator_recursive(10)
            self.assertFalse(torch.is_grad_enabled())
            enable_grad_context_manager_recursive(10)
            self.assertFalse(torch.is_grad_enabled())

    def test_set_grad_coroutines(self):
        @torch.no_grad()
        def coro_no_grad(n=10):
            self.assertFalse(torch.is_grad_enabled())
            for i in range(n):
                self.assertFalse(torch.is_grad_enabled())
                r = yield i
                self.assertFalse(torch.is_grad_enabled())
                self.assertEqual(i, r)
            self.assertFalse(torch.is_grad_enabled())

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            self.assertTrue(torch.is_grad_enabled())
            for i in range(n):
                self.assertTrue(torch.is_grad_enabled())
                r = yield i
                self.assertTrue(torch.is_grad_enabled())
                self.assertEqual(i, r)
            self.assertTrue(torch.is_grad_enabled())

        with torch.enable_grad():
            self.assertTrue(torch.is_grad_enabled())
            coro, r = coro_no_grad(), None
            try:
                while True:
                    self.assertTrue(torch.is_grad_enabled())
                    r = coro.send(r)
                    self.assertTrue(torch.is_grad_enabled())

            except StopIteration:
                pass

        with torch.no_grad():
            self.assertFalse(torch.is_grad_enabled())
            coro, r = coro_enable_grad(), None
            try:
                while True:
                    self.assertFalse(torch.is_grad_enabled())
                    r = coro.send(r)
                    self.assertFalse(torch.is_grad_enabled())

            except StopIteration:
                pass

    def test_set_grad_coroutines_benign_exceptions(self):
        class RecoverableException(Exception):
            pass

        @torch.no_grad()
        def coro_no_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except RecoverableException:
                    self.assertFalse(torch.is_grad_enabled())
                    has_raised = True

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except RecoverableException:
                    self.assertTrue(torch.is_grad_enabled())
                    has_raised = True

        with torch.enable_grad():
            coro = coro_no_grad()
            assert 0 == next(coro)
            try:
                while True:
                    r = coro.throw(RecoverableException)
                    self.assertLess(r, 0)

            except StopIteration:
                pass

        with torch.no_grad():
            coro = coro_enable_grad()
            assert 0 == next(coro)
            try:
                while True:
                    r = coro.throw(RecoverableException)
                    self.assertLess(r, 0)

            except StopIteration:
                pass

    def test_set_grad_coroutines_critical_exceptions(self):
        class UnrecoverableException(Exception):
            pass

        class SecondaryException(Exception):
            pass

        @torch.no_grad()
        def coro_no_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except UnrecoverableException:
                    self.assertFalse(torch.is_grad_enabled())
                    raise SecondaryException from None

        @torch.enable_grad()
        def coro_enable_grad(n=10):
            has_raised = False
            for i in range(n):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield (-i if has_raised else i)

                except UnrecoverableException:
                    self.assertTrue(torch.is_grad_enabled())
                    raise SecondaryException from None

        with torch.enable_grad():
            coro = coro_no_grad()
            assert 0 == next(coro)
            with self.assertRaises(SecondaryException):
                coro.throw(UnrecoverableException)

        with torch.no_grad():
            coro = coro_enable_grad()
            assert 0 == next(coro)
            with self.assertRaises(SecondaryException):
                coro.throw(UnrecoverableException)

    def test_set_grad_coroutines_exit(self):
        @torch.no_grad()
        def coro_no_grad(state):
            for i in range(10):
                try:
                    self.assertFalse(torch.is_grad_enabled())
                    yield i

                except GeneratorExit:
                    self.assertFalse(torch.is_grad_enabled())
                    state.add("GeneratorExit")
                    raise

        @torch.enable_grad()
        def coro_enable_grad(state):
            for i in range(10):
                try:
                    self.assertTrue(torch.is_grad_enabled())
                    yield i

                except GeneratorExit:
                    self.assertTrue(torch.is_grad_enabled())
                    state.add("GeneratorExit")
                    raise

        state = set()
        with torch.enable_grad():
            coro = coro_no_grad(state)
            for i in range(5):
                next(coro)

            coro.close()
        self.assertTrue("GeneratorExit" in state)

        state = set()
        with torch.no_grad():
            coro = coro_enable_grad(state)
            for i in range(5):
                next(coro)

            coro.close()
        self.assertTrue("GeneratorExit" in state)
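
    # Illustrative sketch, not part of the original suite: the behavior the
    # generator and coroutine tests above rely on, in its simplest form. A
    # grad-mode decorator applied to a generator only affects the generator's
    # own frames; the caller's grad mode is restored whenever the generator is
    # suspended at a yield. The helper name is hypothetical and the leading
    # underscore keeps unittest from collecting it as a test.
    def _sketch_no_grad_generator_restores_mode(self):
        @torch.no_grad()
        def gen():
            self.assertFalse(torch.is_grad_enabled())
            yield

        with torch.enable_grad():
            g = gen()
            next(g)
            # the caller's grad mode is back while the generator is suspended
            self.assertTrue(torch.is_grad_enabled())
            g.close()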

    def test_no_grad_python_function(self):
        """Python Functions should respect grad mode."""
        x = torch.ones(5, 5, requires_grad=True)

        class MyOp(Function):
            @staticmethod
            def forward(self, x):
                return x + 1

            @staticmethod
            def backward(self, dy):
                return dy

        with torch.no_grad():
            y = MyOp.apply(x)
        self.assertFalse(y.requires_grad)

    def test_indexing(self):
        x = torch.arange(1.0, 17).view(4, 4)
        y = Variable(x, requires_grad=True)

        def compare(x, y, idx, indexed_tensor, indexed_var):
            indexed_var_t = indexed_var.data
            if not isinstance(indexed_tensor, torch.Tensor):
                indexed_var_t = indexed_var_t[0]
            self.assertEqual(indexed_tensor, indexed_var_t)

            indexed_var.sum().backward()
            expected_grad = torch.empty(x.size()).fill_(0)
            expected_grad[idx] = 1
            self.assertEqual(y.grad, expected_grad)

        def check_index(x, y, idx):
            if y.grad is not None:
                with torch.no_grad():
                    y.grad.zero_()
            indexed_tensor = x[idx]
            indexed_var = y[idx]
            compare(x, y, idx, indexed_tensor, indexed_var)

        check_index(x, y, 1)
        check_index(x, y, (1, 1))
        check_index(x, y, slice(1, None))
        check_index(x, y, slice(None, 2))
        check_index(x, y, (slice(None, 2), 2))
        check_index(x, y, (slice(1, 2), 2))
        check_index(x, y, (1, slice(2, None)))
        check_index(x, y, (slice(None, None), slice(2, None)))
        check_index(x, y, torch.LongTensor([0, 2]))
        check_index(x, y, torch.rand(4, 4).bernoulli().bool())
        check_index(x, y, (Ellipsis, slice(2, None)))
        check_index(x, y, ([0], [0]))
        check_index(x, y, ([1, 2, 3], [0]))
        check_index(x, y, ([1, 2], [2, 1]))
        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 3]]))
        check_index(x, y, ([slice(None), [2, 3]]))
        check_index(x, y, ([[2, 3], slice(None)]))

        # advanced indexing, with less dim, or ellipsis
        check_index(x, y, ([0]))
        check_index(x, y, ([0],))

        x = torch.arange(1.0, 49).view(4, 3, 4)
        y = Variable(x, requires_grad=True)

        check_index(x, y, (slice(None), [0], [0]))
        check_index(x, y, ([0], [0], slice(None)))
        check_index(x, y, (slice(None), [0, 1, 2], [0]))
        check_index(x, y, ([0, 1, 2], [0], slice(None)))
        check_index(x, y, (slice(None), [1, 2], [2, 1]))
        check_index(x, y, ([1, 2], [2, 1], slice(None)))
        check_index(x, y, (slice(None), [[1, 2], [2, 0]], [[0, 1], [2, 3]]))
        check_index(x, y, ([[1, 2], [3, 0]], [[0, 1], [2, 2]], slice(None)))
        check_index(x, y, (slice(None), slice(None), [2, 1]))
        check_index(x, y, (slice(None), [2, 1], slice(None)))
        check_index(x, y, ([2, 1], slice(None), slice(None)))

        # advanced indexing, with less dim, or ellipsis
        check_index(x, y, ([0],))
        check_index(x, y, ([0], slice(None)))
        check_index(x, y, ([0], Ellipsis))
        check_index(x, y, ([1, 2], [0, 1]))
        check_index(x, y, ([1, 2], [0, 1], Ellipsis))
        check_index(x, y, (Ellipsis, [1, 2], [0, 1]))

        # advanced indexing, with a tensor wrapped in a variable
        z = torch.LongTensor([0, 1])
        zv = Variable(z, requires_grad=False)
        seq = [z, Ellipsis]
        seqv = [zv, Ellipsis]

        if y.grad is not None:
            with torch.no_grad():
                y.grad.zero_()
        indexed_tensor = x[seq]
        indexed_var = y[seqv]
        compare(x, y, seq, indexed_tensor, indexed_var)
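
    # Illustrative sketch, not part of the original suite: the gradient rule
    # that test_indexing above exercises across many index forms. Indexing is
    # a gather in forward, so its backward scatters the incoming gradient into
    # the selected positions and leaves every other entry at zero. The helper
    # name is hypothetical and the leading underscore keeps unittest from
    # collecting it as a test.
    def _sketch_indexing_gradient_scatter(self):
        y = torch.zeros(3, 3, requires_grad=True)
        y[0, 1].backward()
        expected = torch.zeros(3, 3)
        expected[0, 1] = 1
        self.assertEqual(y.grad, expected)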

    def test_indexing_duplicates(self):
        x = torch.arange(1.0, 17).view(4, 4)
        y = Variable(x, requires_grad=True)

        idx = torch.LongTensor([1, 1, 3, 2, 1, 2])
        y[idx].sum().backward()
        expected_grad = torch.zeros(4, 4)
        for i in idx:
            expected_grad[i] += 1
        self.assertEqual(y.grad, expected_grad)

        # with advanced indexing
        x = torch.arange(1.0, 17).view(4, 4)
        y = Variable(x, requires_grad=True)

        idx = [[1, 1, 3, 2, 1, 2], [0]]
        y[idx].sum().backward()
        expected_grad = torch.zeros(4, 4)
        for i in idx[0]:
            for j in idx[1]:
                expected_grad[i][j] += 1

        self.assertEqual(y.grad, expected_grad)

        x = torch.arange(1.0, 17).view(4, 4)
        y = Variable(x, requires_grad=True)
        idx = [[[1, 2], [0, 0]], [[0, 1], [1, 1]]]
        y[idx].sum().backward()
        expected_grad = torch.tensor(
            [
                [0.0, 2.0, 0.0, 0.0],
                [1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 0.0, 0.0],
            ]
        )
        self.assertEqual(y.grad, expected_grad)

        x = torch.arange(1.0, 65).view(4, 4, 4)
        y = Variable(x, requires_grad=True)

        idx = [[1, 1, 1], slice(None), slice(None)]
        y[idx].sum().backward()
        expected_grad = torch.empty(4, 4, 4).zero_()
        expected_grad[1].fill_(3)
        self.assertEqual(y.grad, expected_grad)

    def test_index_backward_does_not_save_tensor(self):
        # Example from https://github.com/pytorch/pytorch/issues/24853.
        # if `index(tensor, indices)` saves `tensor` for backwards, then it will
        # trigger a version check on `tensor` during the backward pass, which
        # will cause the following code to error because `tensor` gets modified
        # by the indexing line.
        a = torch.tensor([1.0, 0, 0])
        b = torch.zeros(3, requires_grad=True)
        tensor = b + 0
        tensor[a != 0] = tensor[a != 0]
        tensor.backward(torch.zeros_like(tensor))

    def test_volatile_deprecated(self):
        v = torch.autograd.torch.randn(3, 3)
        with warnings.catch_warnings(record=True) as w:
            self.assertFalse(v.volatile)
        self.assertIn("volatile", str(w[0].message))

    def test_saved_variables_deprecated(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, tensor1, tensor2):
                ctx.save_for_backward(tensor1, tensor2)
                return tensor1 + tensor2

            @staticmethod
            def backward(ctx, grad_output):
                var1, var2 = ctx.saved_variables
                return (grad_output, grad_output)

        with warnings.catch_warnings(record=True) as warns:
            warnings.simplefilter("always")
            x = torch.randn((3, 3), requires_grad=True)
            y = torch.randn((3, 3), requires_grad=True)
            MyFunction.apply(x, y).sum().backward()

            has_deprecated = (
                "deprecated" in str(warn) and "saved_variables" in str(warn)
                for warn in warns
            )
            has_deprecated = reduce(lambda x, y: x or y, has_deprecated)
            self.assertTrue(has_deprecated)
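
    # Illustrative sketch, not part of the original suite: the non-deprecated
    # counterpart of the `ctx.saved_variables` spelling checked above is
    # `ctx.saved_tensors`, populated by `ctx.save_for_backward` in forward.
    # The helper name is hypothetical and the leading underscore keeps
    # unittest from collecting it as a test.
    def _sketch_saved_tensors_api(self):
        class Square(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x * x

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                return 2 * x * grad_output

        x = torch.tensor([3.0], requires_grad=True)
        Square.apply(x).backward()
        self.assertEqual(x.grad, torch.tensor([6.0]))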

    def test_requires_grad(self):
        x = torch.randn(5, 5)
        y = torch.randn(5, 5)
        z = torch.randn(5, 5, requires_grad=True)
        a = x + y
        self.assertFalse(a.requires_grad)
        b = a + z
        self.assertTrue(b.requires_grad)

        def error():
            raise RuntimeError

        # Make sure backward isn't called on these
        a._backward_hooks = OrderedDict()
        x._backward_hooks = OrderedDict()
        y._backward_hooks = OrderedDict()
        a._backward_hooks["test"] = error
        x._backward_hooks["test"] = error
        y._backward_hooks["test"] = error
        b.backward(torch.ones(5, 5))

    def test_requires_grad_(self):
        x = torch.randn(5, 5)
        y = torch.randn(5, 5, requires_grad=True)
        self.assertIs(x, x.requires_grad_())
        self.assertTrue(x.requires_grad)
        self.assertIs(y, y.requires_grad_())
        self.assertTrue(y.requires_grad)
        self.assertIs(x, x.requires_grad_(True))
        self.assertTrue(x.requires_grad)
        self.assertIs(y, y.requires_grad_(True))
        self.assertTrue(y.requires_grad)
        z = x * y
        self.assertRaises(RuntimeError, lambda: z.requires_grad_(False))
        self.assertIs(z, z.requires_grad_())
        self.assertTrue(z.requires_grad)
        self.assertIs(z, z.requires_grad_(True))
        self.assertTrue(z.requires_grad)

        self.assertIs(x, x.requires_grad_(False))
        self.assertFalse(x.requires_grad)
        self.assertIs(y, y.requires_grad_(False))
        self.assertFalse(y.requires_grad)

    def test_requires_grad_inplace(self):
        a = torch.randn(5, 5)
        b = torch.randn(5, 5, requires_grad=True)
        a += b
        self.assertTrue(a.requires_grad)

        # non-leaf
        a = torch.randn(5, 5) + 0
        b = torch.randn(5, 5, requires_grad=True)
        a += b
        self.assertTrue(a.requires_grad)

    def test_no_requires_grad_inplace(self):
        # basic case, should be able to modify inplace while requires_grad is False
        a = torch.randn(2, 3)
        a.add_(5)
        a.requires_grad = True
        a.sum().backward()
        self.assertEqual(a.grad, torch.ones(2, 3))

        # same but with a view
        a = torch.randn(2, 3)
        b = a[:]
        b.add_(5)
        a.requires_grad = True
        a.sum().backward()
        self.assertEqual(a.grad, torch.ones(2, 3))

        # should fail if requires_grad = True when we modify inplace
        a = torch.randn(2, 3)
        b = a[:]
        a.requires_grad = True
        with self.assertRaises(RuntimeError):
            a.add_(5)
        with self.assertRaises(RuntimeError):
            b.add_(5)

    def test_attribute_deletion(self):
        x = torch.randn((5, 5), requires_grad=True)
        del x.grad
        self.assertIsNone(x.grad)
        with self.assertRaises(RuntimeError):
            del x.data
        with self.assertRaises(TypeError):
            x.data = None
        with self.assertRaises(RuntimeError):
            del x.requires_grad
        with self.assertRaises(RuntimeError):
            del x._grad_fn
        with self.assertRaises(RuntimeError):
            del x._backward_hooks

    def test_duplicate_backward_root(self):
        a = torch.randn(5, 5, requires_grad=True)
        b = torch.randn(5, 5, requires_grad=True)

        x = a * b
        grad_output = torch.randn_like(x)
        torch.autograd.backward([x, x], [grad_output, grad_output])

        self.assertEqual(a.grad, b * grad_output * 2)
        self.assertEqual(b.grad, a * grad_output * 2)

    def test_backward_no_grad(self):
        a = torch.randn(5, 5, requires_grad=True)
        b = a + 2
        with self.assertRaises(RuntimeError):
            torch.autograd.backward([b], [None])

    def test_backward_twice_with_saved_values(self):
        b = torch.randn(3, requires_grad=True, dtype=torch.double)
        c = torch.zeros(3, dtype=torch.double)
        c[[1, 2]] = b[[1, 1]]
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
        self.assertRaisesRegex(
            RuntimeError,
            "Specify retain_graph=True",
            lambda: c.backward(torch.tensor([1, 1, 1], dtype=torch.double)),
        )

    def test_backward_twice_retained_graph_with_saved_values(self):
        b = torch.randn(3, requires_grad=True, dtype=torch.double)
        c = torch.zeros(3, dtype=torch.double)
        c[[1, 2]] = b[[1, 1]]
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))

    def test_backward_twice_without_saved_values(self):
        b = torch.randn(3, requires_grad=True, dtype=torch.double)
        c = b + 1
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))

    def test_backward_twice_retained_graph_without_saved_values(self):
        b = torch.randn(3, requires_grad=True, dtype=torch.double)
        c = torch.zeros(3, dtype=torch.double)
        c[[1, 2]] = b[[1, 1]]
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
        c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
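
    # Illustrative sketch, not part of the original suite: the retain_graph
    # rule behind the four backward-twice tests above. A graph whose backward
    # needs saved tensors is freed by the first backward() call, so a second
    # call only works if the first one passed retain_graph=True. The helper
    # name is hypothetical and the leading underscore keeps unittest from
    # collecting it as a test.
    def _sketch_retain_graph(self):
        x = torch.randn(3, requires_grad=True)
        y = (x * x).sum()  # mul saves its inputs for backward
        y.backward(retain_graph=True)
        y.backward()  # fine, the first call kept the graph alive
        self.assertEqual(x.grad, 4 * x.detach())  # 2x accumulated twice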

    def test_backward_create_graph_warns(self):
        with set_warn_always_context(True):
            b = torch.randn(3, requires_grad=True, dtype=torch.double)
            c = b * b
            with warnings.catch_warnings(record=True) as ws:
                c.backward(torch.ones_like(c), create_graph=True)
                b.grad = None
                self.assertTrue(
                    any(
                        "Using backward() with create_graph=True" in str(w.message)
                        for w in ws
                    )
                )

            # Should not warn for grad
            with warnings.catch_warnings(record=True) as ws:
                torch.autograd.grad(c, b, torch.ones_like(c), create_graph=True)
                self.assertFalse(
                    any(
                        "Using backward() with create_graph=True" in str(w.message)
                        for w in ws
                    )
                )

    def test_next_functions(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)

        a = x + y
        self.assertIsNotNone(a.grad_fn)
        next_functions = a.grad_fn.next_functions
        self.assertEqual(len(next_functions), 2)
        self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
        self.assertEqual(next_functions[0][1], 0)
        self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
        self.assertEqual(next_functions[1][1], 0)

        b = a + 5
        next_functions = b.grad_fn.next_functions
        self.assertEqual(len(next_functions), 2)
        self.assertIs(next_functions[0][0], a.grad_fn)
        self.assertIs(next_functions[1][0], None)
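
    # Illustrative sketch, not part of the original suite: using the
    # next_functions links checked above to walk from an output back to its
    # leaf. Each leaf appears as an AccumulateGrad node whose `variable`
    # attribute is the leaf tensor itself. The helper name is hypothetical and
    # the leading underscore keeps unittest from collecting it as a test.
    def _sketch_walk_graph_to_leaf(self):
        x = torch.randn(3, requires_grad=True)
        out = (x * 2).sum()
        mul_fn = out.grad_fn.next_functions[0][0]
        leaf_fn = mul_fn.next_functions[0][0]
        self.assertIsInstance(leaf_fn, torch._C._functions.AccumulateGrad)
        self.assertIs(leaf_fn.variable, x)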

    def test_next_functions(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)

        a = x + y
        self.assertIsNotNone(a.grad_fn)
        next_functions = a.grad_fn.next_functions
        self.assertEqual(len(next_functions), 2)
        self.assertIsInstance(next_functions[0][0], torch._C._functions.AccumulateGrad)
        self.assertEqual(next_functions[0][1], 0)
        self.assertIsInstance(next_functions[1][0], torch._C._functions.AccumulateGrad)
        self.assertEqual(next_functions[1][1], 0)

        b = a + 5
        next_functions = b.grad_fn.next_functions
        self.assertEqual(len(next_functions), 2)
        self.assertIs(next_functions[0][0], a.grad_fn)
        self.assertIs(next_functions[1][0], None)

    def test_inplace(self):
        x = torch.ones(5, 5, requires_grad=True)
        y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

        z = x * y
        q = z + y
        w = z * y
        z.add_(2)
        # Add doesn't need its inputs to do backward, so it shouldn't raise
        q.backward(torch.ones(5, 5), retain_graph=True)
        # Mul saves both inputs in forward, so it should raise
        self.assertRaises(RuntimeError, lambda: w.backward(torch.ones(5, 5)))

        z = x * y
        q = z * y
        r = z + y
        w = z.add_(y)
        # w is the last expression, so this should succeed
        w.backward(torch.ones(5, 5), retain_graph=True)
        # r doesn't use the modified value in backward, so it should succeed
        r.backward(torch.ones(5, 5), retain_graph=True)
        # q uses dirty z, so it should raise
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        with torch.no_grad():
            x.grad.zero_()
        m = x / 2
        z = m + y / 8
        q = z * y
        r = z + y
        prev_version = z._version
        w = z.exp_()
        self.assertNotEqual(z._version, prev_version)
        r.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(x.grad, torch.ones(5, 5) / 2)
        w.backward(torch.ones(5, 5), retain_graph=True)
        self.assertEqual(x.grad, torch.empty(5, 5).fill_((1 + math.e) / 2))
        self.assertRaises(RuntimeError, lambda: q.backward(torch.ones(5, 5)))

        leaf = torch.ones(5, 5, requires_grad=True)
        x = leaf.clone()
        x.add_(10)
        self.assertEqual(x, torch.ones(5, 5) * 11)
        # x should be still usable
        y = x + 2
        y.backward(torch.ones(5, 5))
        self.assertEqual(leaf.grad, torch.ones(5, 5))
        z = x * y
        x.add_(2)
        self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
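
    # Illustrative sketch (not part of the original suite; helper name is ad hoc)
    # of the rule the comments above rely on: mul saves its inputs for backward,
    # so an in-place update of a saved tensor bumps its version counter and
    # backward refuses to use the stale value.
    def _sketch_inplace_invalidates_saved_input(self):
        a = torch.randn(5, requires_grad=True)
        b = a * 2
        c = b * b  # MulBackward saves b
        before = b._version
        b.add_(1)  # in-place update bumps the version counter
        self.assertNotEqual(b._version, before)
        with self.assertRaisesRegex(
            RuntimeError, "modified by an inplace operation"
        ):
            c.sum().backward()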

    def test_mark_non_differentiable(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                output = input > 0
                ctx.mark_non_differentiable(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                return (grad_output * 0).to(torch.double)

        x = torch.randn(5, 5, requires_grad=True)
        mask = MyFunction.apply(x)
        self.assertFalse(mask.requires_grad)
        y = x.masked_fill(mask, 0)
        y.sum().backward()
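
    # Illustrative sketch (not part of the original suite; class and method names
    # are ad hoc): an output handed to ctx.mark_non_differentiable() behaves like
    # a constant. It does not require grad, and backward() sees an all-zero
    # gradient for it, while the other outputs stay differentiable.
    def _sketch_mark_non_differentiable_extra_output(self):
        class ScaleAndShift(Function):
            @staticmethod
            def forward(ctx, x):
                shifted = x + 1
                ctx.mark_non_differentiable(shifted)
                return x * 2, shifted

            @staticmethod
            def backward(ctx, grad_scaled, grad_shifted):
                # grad_shifted is a zero placeholder for the marked output
                return grad_scaled * 2

        x = torch.randn(4, requires_grad=True)
        scaled, shifted = ScaleAndShift.apply(x)
        self.assertFalse(shifted.requires_grad)
        self.assertTrue(scaled.requires_grad)
        scaled.sum().backward()
        self.assertEqual(x.grad, torch.full((4,), 2.0))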

    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
    def test_mark_non_differentiable_mixed(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                a = input + 1
                b = input + 2
                ctx.mark_non_differentiable(a)
                return a, b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                self.assertTrue((grad_a == 0).all())
                self.assertTrue((grad_b == 1).all())
                return grad_b

        x = torch.randn(5, 5, requires_grad=True)
        a, b = MyFunction.apply(x)
        self.assertFalse(a.requires_grad)
        self.assertTrue(b.requires_grad)
        b.sum().backward()
        self.assertEqual(x.grad, torch.ones(5, 5))

    def test_mark_non_differentiable_none(self):
        # This used to segfault because MyFunction would send back null
        # gradients to MulBackward, which is implemented in C++. C++
        # implemented functions expect incoming grad_outputs to be non-null.
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, input):
                output = input.clone()
                ctx.mark_non_differentiable(output)
                return output

            @staticmethod
            def backward(ctx, grad_output):
                return None

        x = torch.randn(5, 5, requires_grad=True)
        r = MyFunction.apply(x * x)
        (r * x).sum().backward()

    def test_return_duplicate(self):
        class DoubleDuplicate(Function):
            @staticmethod
            def forward(ctx, x):
                output = x * 2
                return output, output

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1 * 2 + grad2 * 2

        def fn(x):
            a, b = DoubleDuplicate.apply(x)
            self.assertIs(a, b)
            return a + b

        x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
        gradcheck(fn, [x])
        gradgradcheck(fn, [x])

    def test_return_duplicate_inplace(self):
        class DoubleInplace(Function):
            @staticmethod
            def forward(ctx, x):
                x.mul_(2)
                ctx.mark_dirty(x)
                return x, x

            @staticmethod
            def backward(ctx, grad1, grad2):
                return grad1 * 2 + grad2 * 2

        def inplace_fn(x):
            a, b = DoubleInplace.apply(x.clone())
            self.assertIs(a, b)
            return a + b

        x = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
        gradcheck(inplace_fn, [x])
        gradgradcheck(inplace_fn, [x])

        # Can't modify leaf variables in-place
        self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x))
        # Functions which modify views in-place must return only one output
        self.assertRaises(RuntimeError, lambda: InplaceFunction.apply(x.clone()[0]))
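
    # Illustrative sketch (not part of the original suite; class and method names
    # are ad hoc): the minimal contract for an in-place custom Function. The
    # mutated input must be declared with ctx.mark_dirty(), and the caller has to
    # hand in a non-leaf (e.g. a clone), since leaves that require grad may not be
    # modified in place.
    def _sketch_mark_dirty_inplace_function(self):
        class AddOneInplace(Function):
            @staticmethod
            def forward(ctx, x):
                x.add_(1)
                ctx.mark_dirty(x)
                return x

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

        x = torch.randn(5, requires_grad=True)
        y = AddOneInplace.apply(x.clone())  # operate on a clone, not the leaf
        y.sum().backward()
        self.assertEqual(x.grad, torch.ones(5))
        # applying it to the leaf itself is rejected
        self.assertRaises(RuntimeError, lambda: AddOneInplace.apply(x))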

    def _test_setitem(self, size, index):
        x = torch.ones(*size, requires_grad=True)
        y = x + 2
        y_version = y._version
        y[index] = 2
        self.assertNotEqual(y._version, y_version)
        y.backward(torch.ones(*size))
        expected_grad = torch.ones(*size)
        expected_grad[index] = 0
        self.assertEqual(x.grad, expected_grad)

    def _test_setitem_tensor(self, size, index):
        x = torch.ones(*size, requires_grad=True)
        y = x + 2
        y_version = y._version
        value = x.new(x[index].size()).fill_(7)
        value.requires_grad = True
        y[index] = value
        self.assertNotEqual(y._version, y_version)
        y.backward(torch.ones(*size))
        expected_grad_input = torch.ones(*size)
        expected_grad_input[index] = 0
        self.assertEqual(x.grad, expected_grad_input)
        self.assertEqual(value.grad, torch.ones_like(value))

        # case when x broadcasts as y[1]
        x = torch.randn(4, requires_grad=True)
        y = torch.zeros(2, 3, 4)
        y[1] = x
        y.backward(torch.randn(2, 3, 4))
        self.assertEqual(x.size(), x.grad.size())

    def test_setitem(self):
        self._test_setitem((5, 5), 1)
        self._test_setitem((5,), 1)
        self._test_setitem((1,), 0)
        self._test_setitem((10,), [[0, 4, 2]])
        self._test_setitem((5, 5), [[0, 4], [2, 2]])
        self._test_setitem((5, 5, 5), [slice(None), slice(None), [1, 3]])
        self._test_setitem((5, 5, 5), [slice(None), [1, 3], slice(None)])
        self._test_setitem((5, 5, 5), [[1, 3], slice(None), slice(None)])
        self._test_setitem((5, 5, 5), [slice(None), [2, 4], [1, 3]])
        self._test_setitem((5, 5, 5), [[1, 3], [2, 4], slice(None)])
        self._test_setitem_tensor((5, 5), 3)
        self._test_setitem_tensor((5, 5), [[0, 1], [1, 0]])
        self._test_setitem_tensor((5,), 3)
        self._test_setitem_tensor(
            (5,), Variable(torch.LongTensor([3]), requires_grad=False).sum()
        )
        self._test_setitem_tensor((5,), [[0, 1, 2, 3]])
        self._test_setitem_tensor((5, 5, 5), [slice(None), slice(None), [1, 3]])
        self._test_setitem_tensor((5, 5, 5), [slice(None), [1, 3], slice(None)])
        self._test_setitem_tensor((5, 5, 5), [[1, 3], slice(None), slice(None)])
        self._test_setitem_tensor((5, 5, 5), [slice(None), [2, 4], [1, 3]])
        self._test_setitem_tensor((5, 5, 5), [[1, 3], [2, 4], slice(None)])
        self._test_setitem_tensor(
            (5, 5, 5),
            [
                Variable(torch.LongTensor([1, 3]), requires_grad=False),
                [2, 4],
                slice(None),
            ],
        )

    def test_setitem_mask(self):
        mask = torch.BoolTensor(5, 5).bernoulli_()
        self._test_setitem((5, 5), Variable(mask))
        self._test_setitem((5,), Variable(mask[0]))
        self._test_setitem((1,), Variable(mask[0, 0:1]))
        self._test_setitem_tensor((5, 5), Variable(mask))
        self._test_setitem_tensor((5,), Variable(mask[0]))

    def test_select_sum(self):
        # both select and sum return Scalars in ATen; ensure they work together.
        x = torch.randn(10, dtype=torch.double, requires_grad=True)

        def func(x):
            return x.select(0, 1).sum()

        gradcheck(func, [x])
        gradgradcheck(func, [x])

    def test_diagonal_expanded_v(self):
        value = torch.rand([])
        v_expanded = torch.tensor(value).expand(10)
        a = torch.rand(10, 10, dtype=torch.double, requires_grad=True)
        (result,) = torch.autograd.grad(a.diagonal(), a, v_expanded)
        self.assertEqual(result, torch.eye(10, dtype=torch.double) * value)

    def test_select_expanded_v(self):
        v_expanded = torch.rand(10).expand(10, 10)
        a = torch.rand(10, 10, 10, requires_grad=True)
        (result,) = torch.autograd.grad(a[0], a, v_expanded)
        expected = torch.zeros(10, 10, 10)
        expected[0] = v_expanded
        self.assertEqual(result, expected)

    def test_slice_expanded_v(self):
        v_expanded = torch.rand(10, 1).expand(2, 10, 10)
        a = torch.rand(10, 10, 10, requires_grad=True)
        (result,) = torch.autograd.grad(a[3:5], a, v_expanded)
        expected = torch.zeros(10, 10, 10)
        expected[3:5] = v_expanded
        self.assertEqual(result, expected)

    def test_unused_output(self):
        x = torch.randn(10, 10, requires_grad=True)
        outputs = x.chunk(5)
        o = outputs[2]
        o = o * 4 + 2
        o.sum().backward()
        expected_grad = torch.zeros(10, 10)
        expected_grad[4:6] = 4
        self.assertEqual(x.grad, expected_grad)

        with torch.no_grad():
            x.grad.zero_()
        grad_output = torch.randn(2, 10)
        outputs = x.chunk(5)
        outputs[0].backward(grad_output)
        expected_grad = torch.zeros(10, 10)
        expected_grad[:2] = grad_output
        self.assertEqual(x.grad, expected_grad)

    # TODO: opinfo this or move to the sparse test suite
    def _test_sparse_gather(self, size_x, size_ind, dim):
        x = torch.randn(size_x, requires_grad=True)
        if len(size_ind) > 0 and len(size_x) > 0:
            ind = torch.randint(x.size(dim), size_ind)
        else:
            ind = torch.zeros(size_ind, dtype=torch.int64)
        out = torch.gather(x, dim, ind, sparse_grad=False)
        grad = torch.rand_like(out)
        out.backward(grad)
        grad_dense = x.grad.clone()
        x.grad = None
        out = torch.gather(x, dim, ind, sparse_grad=True)
        out.backward(grad)
        self.assertEqual(grad_dense, x.grad.to_dense())

    def test_sparse_gather_dim0(self):
        self._test_sparse_gather((10, 10), (5, 10), 0)

    def test_sparse_gather_dim1(self):
        self._test_sparse_gather((10, 10, 5), (10, 5, 5), 1)

    def test_sparse_gather_dim_neg(self):
        self._test_sparse_gather((10, 10, 5), (10, 10, 2), -1)

    def test_sparse_gather_ind_scalar(self):
        self._test_sparse_gather((10,), (), 0)

    def test_sparse_gather_x_scalar(self):
        self._test_sparse_gather((), (2,), 0)

    def test_sparse_gather_both_scalar(self):
        self._test_sparse_gather((), (), 0)

    def test_gc_in_destructor(self):
        """
        Previously, if a Function destructor triggered a garbage collection,
        the Variable's tp_dealloc handler would get called twice leading to a
        segfault.
        """

        class CollectOnDelete(Function):
            def forward(self, x):
                return x

            def backward(self, grad_output):
                return grad_output

            def __del__(self):
                gc.collect()

        for _ in range(10):
            CollectOnDelete().forward(torch.randn(1, requires_grad=True)).backward()

    def test_naughty_autograd_function_attribute_access(self):
        class Id(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad_x):
                return grad_x

        with self.assertWarnsRegex(DeprecationWarning, "should not be instantiated"):
            f = Id()

        # After raising warning, should still return an instance
        self.assertIsInstance(f, Id)
        x = torch.zeros(1, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError, "non-static forward method is deprecated"
        ):
            f(x)
        t = Id.apply(x)
        self.assertEqual(t.grad_fn.name(), "IdBackward")

        # THPFunction is the base class of both grad_fn and autograd functions,
        # which means that a lot of accessors on them may segfault. Test that we
        # properly error in this case.
        t = torch.ones(1, requires_grad=True)
        t._backward_hooks = {}
        with self.assertRaisesRegex(
            RuntimeError, "Attribute '_register_hook_dict' is invalid"
        ):
            f._register_hook_dict(t)
        with self.assertRaisesRegex(
            RuntimeError, "Attribute 'register_hook' is invalid"
        ):
            f.register_hook(lambda x, y: None)
        with self.assertRaisesRegex(
            RuntimeError, "Attribute 'next_functions' is invalid"
        ):
            f.next_functions
        with self.assertRaisesRegex(RuntimeError, "Attribute 'name' is invalid"):
            f.name()
        with self.assertRaisesRegex(
            RuntimeError, "underlying PyNode has already been deallocated"
        ):
            f.metadata

    @unittest.expectedFailure
    def test_naughty_anomaly_access(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, g):
                return g

        x = torch.zeros(1, requires_grad=True)
        y = MyFunction.apply(x)
        y.backward()
        y.grad_fn.metadata
        g = y.grad_fn
        del y
        g.metadata  # this currently fails, but shouldn't

    def test_naughty_autograd_function_stashing_ctx(self):
        saved_ctx = []

        class Id(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, grad_x):
                saved_ctx.append(ctx)
                return ctx.saved_tensors

        p = torch.zeros(1, requires_grad=True)
        loss = Id.apply(p)
        loss.backward(retain_graph=True)
        del loss
        # At this point in time, it complains that the graph has been freed
        # (which is indeed true, although a somewhat indirect way of stating
        # the problem).
        self.assertRaises(RuntimeError, lambda: saved_ctx[0].saved_tensors)
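
    # Illustrative sketch (not part of the original suite; helper name is ad hoc):
    # the flip side of the "graph has been freed" error above. Passing
    # retain_graph=True keeps the saved buffers alive, so the same graph can be
    # differentiated repeatedly.
    def _sketch_repeated_grad_with_retain_graph(self):
        x = torch.randn(4, requires_grad=True)
        y = x * x
        grads = []
        for i in range(y.numel()):
            (g,) = torch.autograd.grad(y[i], x, retain_graph=True)
            grads.append(g)
        # each call contributes 2 * x[i] at position i and zero elsewhere
        self.assertEqual(sum(grads), 2 * x.detach())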

    def test_custom_autograd_repeated_grad_grad(self):
        # This test failed the equality check in PR #22983; it's an interesting
        # and different test case worth enshrining. mult1 is not testing
        # anything that interesting, but mult2 is the interesting case.

        def mult1(x):
            return x.prod(dim=-1).prod(dim=-1)

        class Mult(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = mult1(x)
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return (grad_output * y)[:, None, None] / x

        mult2 = Mult.apply

        def check_gradgrad_repeated(x, y):
            (gy,) = torch.autograd.grad(y[0], x, create_graph=True)
            (ggy_1,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
            (gy,) = torch.autograd.grad(y[0], x, create_graph=True)
            (ggy_2,) = torch.autograd.grad(gy[0, 0, 0], x, retain_graph=True)
            self.assertEqual(ggy_1[0, 0, 1], ggy_2[0, 0, 1])

        x = torch.ones(2, 4, 4).requires_grad_()
        check_gradgrad_repeated(x, mult1(x))
        check_gradgrad_repeated(x, mult2(x))

    def test_custom_autograd_no_early_free(self):
        # This test failed complaining that buffers had already been freed
        # prior to #22983. Also a pretty interesting test case.
        class Double(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x, y)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                x, _ = ctx.saved_tensors
                return grad_output * 2 * x

        # this is equivalent, but uses the output of .forward() in .backward()
        class Double2(Double):
            @staticmethod
            def backward(ctx, grad_output):
                x, y = ctx.saved_tensors
                return grad_output * 2 * y / x

        double = Double.apply
        double2 = Double2.apply

        x = torch.tensor(2).double().requires_grad_()

        self.assertTrue(gradcheck(double, x))
        self.assertTrue(gradgradcheck(double, x))
        self.assertTrue(gradcheck(double2, x))
        self.assertTrue(gradgradcheck(double2, x))

        y = double(x)
        torch.autograd.grad(y, x, create_graph=True)
        torch.autograd.grad(y, x)

        y = double2(x)
        torch.autograd.grad(y, x, create_graph=True)
        torch.autograd.grad(y, x)  # should not error!

    def test_detach(self):
        x = torch.randn(10, 10, requires_grad=True)
        y = x + 2
        y = y.detach()
        z = y * 4 + 2
        self.assertFalse(y.requires_grad)
        self.assertFalse(z.requires_grad)

        x = torch.randn(10, 10, requires_grad=True)
        y = x * 2
        y = y.detach()
        self.assertFalse(y.requires_grad)
        self.assertIsNone(y.grad_fn)
        z = x + y
        z.sum().backward()
        # This is an incorrect gradient, but we assume that's what the user
        # wanted. detach() is an advanced option.
        self.assertEqual(x.grad, torch.ones(10, 10))

        # in-place detach
        x = torch.randn(10, 10, requires_grad=True)
        y = torch.randn(10, 10, requires_grad=True)
        a = x * 2
        (y + a).sum().backward(retain_graph=True)
        a.detach_()
        self.assertFalse(a.requires_grad)
        (y + a).sum().backward()  # this won't backprop to x
        self.assertEqual(x.grad, torch.ones(10, 10) * 2)
        self.assertEqual(y.grad, torch.ones(10, 10) * 2)

        # in-place detach on a view raises an exception
        view = x.narrow(0, 1, 4)
        self.assertRaisesRegex(RuntimeError, "view", lambda: view.detach_())
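
    # Illustrative sketch (not part of the original suite; helper name is ad hoc):
    # detach() returns a tensor that shares storage with the original but is cut
    # out of the graph, which is why the gradient above stops at the detached
    # value.
    def _sketch_detach_shares_storage(self):
        x = torch.randn(5, requires_grad=True)
        d = x.detach()
        self.assertFalse(d.requires_grad)
        self.assertEqual(d.data_ptr(), x.data_ptr())  # same underlying storage
        d.zero_()  # writes through the detached alias are visible in x
        self.assertEqual(x, torch.zeros(5))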

    def test_detach_base(self):
        "detaching base does not detach view"
        x = torch.randn(10, 10, requires_grad=True)
        view = x.narrow(0, 1, 4)
        x.detach_()
        self.assertFalse(x.requires_grad)
        self.assertTrue(view.requires_grad)
        self.assertIsNotNone(view.grad_fn)
        self.assertIs(view._base, x)

    def test_detach_then_inplace_raises_in_autograd(self):
        x = torch.randn([], requires_grad=True)
        orig_x = x.detach().clone()

        y = x**2  # saves x
        z = x.detach()
        z.zero_()
        with self.assertRaisesRegex(RuntimeError, "has been modified by an inplace"):
            y.backward()

    def _test_type_conversion_backward(self, t):
        fvar = Variable(t(torch.randn(5, 5).float()), requires_grad=True)
        fvar.double().sum().backward()
        self.assertEqual(fvar.grad, torch.ones_like(fvar))
        self.assertEqual(type(fvar.grad), type(fvar))
        dvar = Variable(t(torch.randn(5, 5).double()), requires_grad=True)
        dvar.float().sum().backward()
        self.assertEqual(dvar.grad, torch.ones_like(dvar))
        self.assertEqual(type(dvar.grad), type(dvar))

    def test_type_conversions(self):
        x = torch.randn(5, 5)
        self.assertIsInstance(x.float(), torch.FloatTensor)
        self.assertIsInstance(x.int(), torch.IntTensor)
        if torch.cuda.is_available():
            self.assertIsInstance(x.float().cuda(), torch.cuda.FloatTensor)
            self.assertIsInstance(x.int().cuda(), torch.cuda.IntTensor)
            self.assertIsInstance(x.int().cuda().cpu(), torch.IntTensor)
            if torch.cuda.device_count() >= 2:
                x2 = x.float().cuda(1)
                self.assertIsInstance(x2, torch.cuda.FloatTensor)
                self.assertIs(x2.get_device(), 1)
                x2 = x.float().cuda()
                self.assertIsInstance(x2, torch.cuda.FloatTensor)
                self.assertIs(x2.get_device(), 0)
                x2 = x2.cuda(1)
                self.assertIsInstance(x2, torch.cuda.FloatTensor)
                self.assertIs(x2.get_device(), 1)
                y = Variable(torch.randn(5).cuda(1), requires_grad=True)
                y.cpu().sum().backward()
                self.assertIs(y.grad.get_device(), 1)
                self.assertIs(y.long().get_device(), 1)

        for t in [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.IntTensor,
            torch.ByteTensor,
        ]:
            for y_var in (True, False):
                y = torch.randint(5, (5, 5), dtype=t.dtype)
                y = Variable(y) if y_var else y
                self.assertIsInstance(x.type(t), t)
                self.assertIsInstance(x.type_as(y), t)
                # TODO: t.dtype should work
                t_dtype = t().dtype
                self.assertIsInstance(x.type(t_dtype), t)
                self.assertIs(t_dtype, x.type(t_dtype).dtype)
                self.assertEqual(y.data_ptr(), y.type(t).data_ptr())
                if torch.cuda.is_available():
                    for x_cuda in (True, False):
                        for y_cuda in (True, False):
                            x_c = x.cuda() if x_cuda else x
                            y_c = y.cuda() if y_cuda else y
                            _, y_type = y_c.type().rsplit(".", 1)
                            y_typestr = ("torch.cuda." if y_cuda else "torch.") + y_type
                            self.assertEqual(y_c.type(), x_c.type(y_typestr).type())
                            self.assertIs(y_c.dtype, x_c.type(y_c.dtype).dtype)
                            self.assertEqual(
                                y_c.data_ptr(),
                                y_c.cuda().data_ptr() if y_cuda else y_c.data_ptr(),
                            )

        self._test_type_conversion_backward(lambda x: x)
        if torch.cuda.is_available():
            self._test_type_conversion_backward(lambda x: x.cuda())
            if torch.cuda.device_count() >= 2:
                # one of these has to be the non-default device
                self._test_type_conversion_backward(lambda x: x.cuda(0))
                self._test_type_conversion_backward(lambda x: x.cuda(1))

    def test_isolated_node(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)

        a = x + y
        b = torch.max(a, 1, True)[1].repeat(1, 5).double()
        o = (b + a).sum()
        o.backward()

    def test_shape(self):
        x = torch.randn(3, 4)
        self.assertEqual(2, len(x.shape))
        self.assertEqual(x.shape[0], 3)
        self.assertEqual(x.shape[1], 4)

    def test_numpy_requires_grad(self):
        x = torch.randn(2, 2, requires_grad=True)
        err_msg_outputs = r"Can't call numpy\(\) on Tensor that requires grad. Use tensor.detach\(\).numpy\(\) instead."
        with self.assertRaisesRegex(RuntimeError, err_msg_outputs):
            x.numpy()

        with torch.no_grad():
            x.numpy()

        x = torch.randn(2, 2)
        x.numpy()

        with torch.no_grad():
            x.numpy()
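
    # Illustrative sketch (not part of the original suite; helper name is ad hoc):
    # the escape hatch the error message above points to. detach().numpy() is
    # allowed, and the array aliases the tensor's memory, so later in-place
    # updates show through.
    def _sketch_detach_numpy_shares_memory(self):
        x = torch.zeros(3, requires_grad=True)
        arr = x.detach().numpy()
        with torch.no_grad():
            x += 1  # in-place update of the leaf (legal under no_grad)
        self.assertEqual(float(arr[0]), 1.0)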

    def test_return_leaf(self):
        class Identity(Function):
            @staticmethod
            def forward(ctx, a, b):
                return a, a + b

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                return grad_a + grad_b, grad_b

        hook_called = [False]
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)

        q, p = Identity.apply(x, y)

        # Make sure hooks only receive grad from usage of q, not x.
        def hook(grad):
            hook_called[0] = True
            self.assertEqual(grad, torch.ones(5, 5))

        q.register_hook(hook)
        (q + p + x).sum().backward()
        self.assertEqual(x.grad, torch.ones(5, 5) * 3)
        self.assertEqual(y.grad, torch.ones(5, 5))
        self.assertTrue(hook_called[0])
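
    # Illustrative sketch (not part of the original suite; helper name is ad hoc):
    # a hook registered on a non-leaf tensor sees the gradient flowing through
    # that tensor and may replace it, which is the mechanism the hook above
    # relies on.
    def _sketch_tensor_hook_rescales_grad(self):
        x = torch.ones(3, requires_grad=True)
        y = x * 2
        y.register_hook(lambda grad: grad * 3)  # rescale the incoming gradient
        y.sum().backward()
        # without the hook x.grad would be 2; the hook multiplies it by 3
        self.assertEqual(x.grad, torch.full((3,), 6.0))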

    def test_return_leaf_inplace(self):
        class Inplace(InplaceFunction):
            @staticmethod
            def forward(ctx, a, b):
                ctx.mark_dirty(a)
                return a.add_(b), b + 2

            @staticmethod
            def backward(ctx, grad_a, grad_b):
                return grad_a, grad_a + grad_b

        x = torch.randn(5, 5)
        y = torch.randn(5, 5, requires_grad=True)

        q, p = Inplace.apply(x, y)
        self.assertIs(q, x)
        self.assertIs(q.grad_fn.__class__, Inplace._backward_cls)
        self.assertTrue(q.requires_grad)
        q.sum().backward()
        self.assertEqual(y.grad, torch.ones(5, 5))

    def test_leaf_assignment(self):
        x = torch.randn(5, 5)
        y = torch.randn(5, requires_grad=True)
        z = torch.randn(5, requires_grad=True)

        x[0] = y
        x[1] = 2 * z
        self.assertTrue(x.requires_grad)
        self.assertIsNot(x.grad_fn, None)
        x.sum().backward()
        self.assertEqual(y.grad, torch.ones(5))
        self.assertEqual(z.grad, torch.ones(5) * 2)

    def test_no_grad_assignment(self):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5)
        with torch.no_grad():
            x[0] = y

        self.assertTrue(x.requires_grad)
        self.assertIsNone(x.grad_fn)

    def test_no_grad_modifies_version(self):
        x = torch.randn(5, requires_grad=True)
        y = torch.randn(5, requires_grad=True)
        z = (x * y).sum()
        with torch.no_grad():
            x *= 2
        self.assertRaisesRegex(
            RuntimeError, "modified by an inplace operation", lambda: z.backward()
        )

    def test_increment_version(self):
        a = torch.rand(5, requires_grad=True)
        v = a._version
        torch.autograd.graph.increment_version(a)
        self.assertEqual(a._version, v + 1)

        a = torch.zeros(5, dtype=torch.int)
        v = a._version
        torch.autograd.graph.increment_version(a)
        self.assertEqual(a._version, v + 1)

        with torch.inference_mode():
            a = torch.rand(5, requires_grad=True)
            # does not error
            torch.autograd.graph.increment_version(a)

        # does not error
        torch.autograd.graph.increment_version(a)
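
    # Illustrative sketch (not part of the original suite; helper name is ad hoc):
    # the version counter that increment_version() manipulates above is the same
    # one every in-place op bumps, and it is what lets backward detect stale
    # saved tensors.
    def _sketch_version_counter_tracks_inplace_ops(self):
        a = torch.rand(5)
        v = a._version
        a.add_(1)  # any in-place op bumps the counter
        self.assertEqual(a._version, v + 1)
        a.mul_(2)
        self.assertEqual(a._version, v + 2)
        b = a.clone()  # out-of-place ops leave the new tensor at version 0
        self.assertEqual(b._version, 0)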

    def test_no_grad_input(self):
        class MyFunction(Function):
            @staticmethod
            def forward(self, x):
                return x

            @staticmethod
            def backward(self, grad_output):
                return grad_output

        x = torch.randn(5, requires_grad=True)
        with torch.no_grad():
            y = MyFunction.apply(x)

        self.assertTrue(x.requires_grad)
        self.assertIsNone(y.grad_fn)

    def test_backward_copy(self):
        # This test checks the backward engine for a very subtle bug that
        # appeared in one of the initial versions of autograd. Gradient tensors
        # were simply stored in lists while the function waited for all its
        # gradients to be computed. However, sometimes an output was used
        # multiple times, so the gradients needed to be summed. The engine used
        # to keep a need_copy set of tensors that would need a clone upon the
        # next addition and removed them from the set as soon as the clone was
        # performed. However, this could lead to incorrect results if the same
        # gradient tensor was buffered in three places in the graph:
        # 1. When accumulating gradients in one of these places it was cloned
        #    and removed from the need_copy set.
        # 2. When accumulating in the second place, it wasn't in the need_copy
        #    set, so the gradients were simply accumulated in-place (which
        #    already modified the grad in the 3rd place)
        # 3. When accumulating in the third place, it wasn't in the need_copy
        #    set either, so the incoming gradient was summed in-place, yielding
        #    incorrect results in all functions, except the first one.
        x = torch.ones(5, 5, requires_grad=True)
        y = torch.ones(5, 5, requires_grad=True)
        # Simulate that we're in the middle of the graph
        a = x + 2
        b = y + 2
        c = x + 2
        # This op will just return grad_output two times in backward
        add1 = a + b
        add2 = add1 + c
        # Simulate a long branch, so grad_output will get buffered.
        for _ in range(4):
            a = a * 2
            b = b * 2
            c = c * 2
        branch = a + b + c
        out = add2 + branch
        # expected gradients are:
        # for x: 34 (16 from final a, 16 from final c, 2 from add2)
        # for y: 17 (16 from final b, 1 from add2)
        grad_output = torch.ones(5, 5)
        out.backward(grad_output)
        self.assertEqual(x.grad, torch.ones(5, 5) * 34)
        self.assertEqual(y.grad, torch.ones(5, 5) * 17)

    def test_save_none_for_backward(self):
        test_case = self

        class MyFn(Function):
            @staticmethod
            def forward(ctx, input):
                ctx.save_for_backward(None, input, None)
                return input * input

            @staticmethod
            def backward(ctx, grad_output):
                n1, input, n2 = ctx.saved_tensors
                test_case.assertIsNone(n1)
                test_case.assertIsNone(n2)
                return 2 * input * grad_output

        x = torch.randn(5, 5, requires_grad=True)
        y = MyFn.apply(x)
        y.sum().backward()
        self.assertEqual(x.grad, 2 * x)

    def test_too_many_grads(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, None, None

        x = torch.randn(5, 5, requires_grad=True)
        y = MyFn.apply(x)
        y.sum().backward()
        self.assertEqual(x.grad, torch.ones_like(x))

    def test_pickle(self):
        x = torch.randn(10, 10, requires_grad=True)
        y = torch.randn(10, 10, requires_grad=False)

        def assert_strict_equal(var1, var2):
            self.assertEqual(var1, var2)
            self.assertEqual(var1.requires_grad, var2.requires_grad)

        serialized = [pickle.dumps([x, y], protocol=p) for p in range(3)]
        for dump in serialized:
            xc, yc = pickle.loads(dump)
            assert_strict_equal(xc, x)
            assert_strict_equal(yc, y)

    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
    def test_dep_nograd(self):
        class F1(Function):
            @staticmethod
            def forward(ctx, input):
                out = torch.randn(input.size())
                ctx.mark_non_differentiable(out)
                return input, out

            @staticmethod
            def backward(ctx, grad_output, ignored):
                return grad_output

        class F2(Function):
            @staticmethod
            def forward(ctx, input, ignored):
                return input

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, None

        x = torch.randn(5, requires_grad=True)
        a, b = F1.apply(x)
        b = b + 1  # separate F1 from F2 by another op
        self.assertTrue(a.requires_grad)
        self.assertFalse(b.requires_grad)
        c = F2.apply(a, b)
        c.backward(torch.ones(c.size()))
        self.assertEqual(x.grad, torch.ones(x.size()))

    def test_set_grad_enabled(self):
        x = torch.tensor([1.0], requires_grad=True)
        with torch.set_grad_enabled(False):
            y = x * 2
        self.assertFalse(y.requires_grad)
        with torch.set_grad_enabled(True):
            y = x * 2
        self.assertTrue(y.requires_grad)
        with torch.set_grad_enabled(False):
            torch.set_grad_enabled(True)
            y = x * 2
        self.assertTrue(y.requires_grad)

    def test_set_grad_enabled_wraps(self):
        for decorator in [True, False]:
            with torch.enable_grad():
                self.assertTrue(torch.is_grad_enabled())

                if decorator:
                    # This should not mutate the global grad mode!
                    @torch.set_grad_enabled(False)
                    def inner_func(x):
                        return x.sin()

                else:

                    def inner_func(x):
                        return x.sin()

                    # This is non-idiomatic usage!
                    # More idiomatic usage: torch.set_grad_enabled(False)(inner_func)
                    obj = torch.set_grad_enabled(False)
                    self.assertTrue(not torch.is_grad_enabled())

                    # this will consume the set_grad_enabled global mutation!
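                    # (as the assertions below show, wrapping a callable with the
                    # mode object restores the ambient grad mode, while the wrapped
                    # function itself runs with grad disabled)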
                    inner_func = obj(inner_func)
                self.assertTrue(torch.is_grad_enabled())

            self.assertTrue(torch.is_grad_enabled())

            x = torch.zeros(1, requires_grad=True)
            self.assertTrue(not inner_func(x).requires_grad)

    def test_simple_reentrant(self):
        y_data = torch.randn(2, 2)

        class Reenter(Function):
            @staticmethod
            def forward(ctx, x):
                with torch.enable_grad():
                    ctx.x = Variable(x, requires_grad=True)
                    ctx.y = Variable(y_data, requires_grad=True)
                    ctx.output_var = ctx.x * ctx.y
                return ctx.output_var.detach()

            @staticmethod
            def backward(ctx, grad_output):
                with torch.enable_grad():
                    ctx.output_var.sum().backward()
                return ctx.x.grad * grad_output

        # Reentrant starts on CPU thread, finishes on GPU thread
        x = torch.randn(2, 2, requires_grad=True)
        out = Reenter.apply(x)
        out.sum().backward()
        self.assertEqual(x.grad, y_data)

    def test_reentrant_child_error(self):
        # Parent graph.
        a = torch.rand(3, 3, requires_grad=True)
        c = a * a

        # Reentrant child graph.
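        # SimulateBackwardError's backward raises; running backward on this child
        # graph from inside ReentrantFunc.backward below re-enters the engine, and
        # the child's error must surface from the parent backward call.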
        b = torch.rand(3, 3, requires_grad=True)
        e = b * b
        f = TestAutograd.SimulateBackwardError.apply(e)
        reentrant_root = f.sum()

        class ReentrantFunc(Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, grad):
                # Reentrant backward in child will throw an error.
                reentrant_root.backward()
                return grad

        d = ReentrantFunc.apply(c)
        with self.assertRaisesRegex(Exception, "Simulate error"):
            d.sum().backward()

    def test_var_mean_differentiable(self):
        dim = [2, 4]
        keepdim = False
        input1 = torch.randn(3, 4, 5, 6, 2, 3, requires_grad=True)
        input2 = deepcopy(input1)
        var1, mean1 = torch.var_mean(input1, dim=dim, keepdim=keepdim)
        var2 = input2.var(dim=dim, keepdim=keepdim)
        mean2 = input2.mean(dim=dim, keepdim=keepdim)
        grad = torch.randn(3, 4, 6, 3, requires_grad=True)

        r1 = var1 * var1 * mean1 * mean1
        r2 = var2 * var2 * mean2 * mean2
        self.assertEqual(r1, r2, rtol=0.01, atol=0.0)

        torch.autograd.backward(r1, grad)
        torch.autograd.backward(r2, grad)
        self.assertEqual(input1.grad, input2.grad, rtol=0.01, atol=0.0)

    @skipIfNoLapack
    def test_lobpcg(self):
        def func(k, A, largest=True, B=None):
            X_shape = list(A.shape)
            X_shape[-1] = k
            X = torch.eye(A.size(-2), k, dtype=A.dtype, device=A.device)
            if A.dim() > 2:
                X = X.expand(X_shape)

            D, U = torch.lobpcg(A=A, k=k, B=B, X=X, largest=largest)

            # LOBPCG uses a random initial eigenspace approximation
            # if parameter `X` is not provided.
            # This may cause non-deterministic behavior
            # when it comes to the sign of an eigenvector
            # (note that if v is an eigenvector, so is -v),
            # hence we eliminate this non-determinism
            # by making sure that each column of U
            # gets multiplied by the sign of its max (in absolute value) element.
            # Also, gradcheck changes the content of the input by +/- eps (defaults to 1e-06)
            # to compute the numerical gradient, which can also cause the signs to flip.
            _, idx = U.abs().max(-2, keepdim=True)
            sign = U.gather(-2, idx).sign()
            U = U * sign
            return D, U

        # TODO: review if this can be ported to OpInfos or moved to test_linalg.py
        def run_symeig_test(k, sizes, largest=True):
            A = torch.rand(*sizes).double()
            A = (A @ A.mT) / 10
            A.requires_grad_(True)

            gradcheck(lambda A: func(k, A, largest), A, check_batched_grad=False)

            # Custom gradient vectors for better stability due to some
            # non-determinism in lobpcg's forward.
            # Note: this is not required if symeig is used in the forward instead (tested).
            D_grad = torch.rand(*A.shape[:-2], k) / 100
            U_grad = torch.rand(*A.shape[:-1], k) / 100
            gradgradcheck(
                lambda A: func(k, A, largest),
                A,
                [D_grad, U_grad],
                atol=1e-4,
                check_batched_grad=False,
            )

            # check whether A.grad is symmetric
            A = A.detach().requires_grad_(True)
            D, U = func(k, A, largest)
            (D.sum() + U.sum()).backward()
            self.assertEqual(A.grad, A.grad.mT)

        for largest in [True, False]:
            run_symeig_test(1, (6, 6), largest=largest)
            run_symeig_test(1, (2, 6, 6), largest=largest)
            run_symeig_test(1, (2, 2, 6, 6), largest=largest)
            run_symeig_test(2, (6, 6), largest=largest)
            run_symeig_test(2, (2, 6, 6), largest=largest)
            run_symeig_test(2, (2, 2, 6, 6), largest=largest)
            run_symeig_test(3, (9, 9), largest=largest)
            run_symeig_test(3, (2, 9, 9), largest=largest)
            run_symeig_test(3, (2, 2, 9, 9), largest=largest)

    def test_variable_traverse(self):
        def get_out_and_unrefed_cycle():
            inp = torch.randn(10, requires_grad=True)
            tmp = inp.view(10, 1)
            out = tmp.view(10)

            # Create a reference cycle that contains an
            # intermediary Variable in the graph
            my_list = []
            my_list.append(tmp)
            my_list.append(my_list)

            return out

        out = get_out_and_unrefed_cycle()
        gc.collect()
        # This will segfault if things have been erroneously released
        out.backward(torch.randn(out.size()))

    # TODO: review porting these to OpInfo tests
    def test_pow_zero_tensor_gradient(self):
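        # d/dx x**p is p * x**(p - 1); evaluated naively at x == 0 with p == 0 this
        # becomes 0 * inf, so the checks below assert that the gradient comes back
        # as exactly 0 rather than NaN.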
        def run_test(input_size, exponent):
            input = torch.zeros(*input_size, requires_grad=True)
            input.pow(exponent).sum().backward()
            self.assertEqual(input.grad.abs().sum(), 0)

        run_test((10,), torch.zeros(10))
        run_test((10, 10), torch.zeros(10, 10))
        run_test((10,), 0)

    def test_current_graph_task_id(self):
        id = [-1]

        def hook(_):
            id[0] = torch._C._current_graph_task_id()

        t = torch.tensor(1.0, requires_grad=True).clone()
        t.register_hook(hook)

        t.backward(retain_graph=True)
        base = id[0]
        t.backward(retain_graph=True)
        self.assertEqual(id[0] - base, 1)
        t.backward(retain_graph=True)
        self.assertEqual(id[0] - base, 2)

        self.assertEqual(torch._C._current_graph_task_id(), -1)

    def test_current_graph_task_execution_order(self):
        predicted = [None]

        def hook(_):
            predicted[0] = torch._C._current_graph_task_execution_order()

        def names(nodes):
            return ", ".join([node.name().split(" ")[-1] for node in nodes]) + "\n"

        def grad_fns(*tensors):
            # or grad accumulator
            out = []
            for t in tensors:
                if t.requires_grad and t.grad_fn is None:
                    out.append(t.clone().grad_fn.next_functions[0][0])
                else:
                    out.append(t.grad_fn)
            return out

        actual = []

        def register_logging_hooks(*tensors):
            # register hooks that log the order in which they are called
            def get_hook(i):
                def hook(t_):
                    actual.append(tensors[i])

                return hook

            for i, t in enumerate(tensors):
                t.register_hook(get_hook(i))

        # Basic example: single path
        t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
        t.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            t.backward()
        self.assertExpectedInline(
            names(predicted[0]),
            """\
ExpBackward0, SinBackward0, CloneBackward0, torch::autograd::AccumulateGrad
""",
        )

        # We don't exactly follow sequence_nr order
        a = torch.tensor(1.0, requires_grad=True)
        b = torch.tensor(2.0, requires_grad=True)
        c = b.sin()
        d = a.cos()
        out = c * d
        register_logging_hooks(a, b, c, d, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            out.backward()
        self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Accumulate grad node has more than one input
        a = torch.tensor(1.0, requires_grad=True)
        b = a.sin()
        c = a.cos()
        out = b * c
        register_logging_hooks(a, b, c, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            out.backward()
        self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Multiple roots are also OK
        a = torch.tensor(1.0, requires_grad=True)
        b = a * 2
        out = b.sin()
        out2 = b.cos()
        out3 = b.cos()
        register_logging_hooks(a, b, out, out2, out3)
        out3.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.grad((out, out3, out2), inputs=(a,))
        self.assertExpectedInline(
            names(predicted[0]),
            """\
CosBackward0, CosBackward0, SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
""",
        )
        # TODO: Uncomment after update to hooks behavior
        # self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Case where next node is nullptr
        a = torch.tensor(1.0, requires_grad=True)
        b = a * 2
        out = b.sin()
        register_logging_hooks(a, b, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            out.backward()
        self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Case where two `inputs` are on the same path
        a = torch.tensor(1.0, requires_grad=True)
        b = a * 2
        out = b.sin()
        register_logging_hooks(a, b, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.grad((out,), inputs=(a, b))
        self.assertEqual(
            names(predicted[0]),
            """\
SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
""",
        )
        # TODO: Uncomment after update to hooks behavior
        # self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Case where `inputs` specifies a subgraph
        a = torch.tensor(1.0, requires_grad=True)
        b = torch.tensor(1.0, requires_grad=True)
        c = a * b
        out = c.sin()
        register_logging_hooks(a, b, c, out)
        out.register_hook(hook)
        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.grad((out,), inputs=(a,))
        self.assertEqual(
            names(predicted[0]),
            """\
SinBackward0, MulBackward0, torch::autograd::AccumulateGrad
""",
        )
        # TODO: Uncomment after update to hooks behavior
        # self.assertEqual(predicted[0], grad_fns(*actual))
        actual = []

        # Errors when not called in a backward
        with self.assertRaisesRegex(
            RuntimeError, "should only be called during the backward pass"
        ):
            torch._C._current_graph_task_execution_order()

        # Errors when the context manager is not enabled
        t = torch.tensor(1.0, requires_grad=True).clone().sin().exp()
        t.register_hook(hook)
        with self.assertRaisesRegex(
            RuntimeError,
            "expects the current backward to be executed with multithreading disabled",
        ):
            t.backward()

    def test_view_replay_enabled(self):
        def f(x):
            out = x.clone().view(-1)
            # mutate the view, triggering autograd view-replay logic
            out.add_(1)
            return out

        x = torch.ones(2, 2, requires_grad=True)

        # Test as a context manager
        with torch.autograd._force_original_view_tracking(False):
            out = f(x)
            self.assertTrue("AsStridedBackward" in str(out.grad_fn))
            self.assertFalse(torch.autograd.is_view_replay_enabled())
        self.assertFalse(torch.autograd.is_view_replay_enabled())

        with torch.autograd._force_original_view_tracking(True):
            out = f(x)
            self.assertTrue("ViewBackward" in str(out.grad_fn))
            self.assertTrue(torch.autograd.is_view_replay_enabled())
        out = f(x)
        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
        self.assertFalse(torch.autograd.is_view_replay_enabled())

        with torch.autograd._force_original_view_tracking(False):
            torch.autograd._force_original_view_tracking(True)
            out = f(x)
            self.assertTrue("ViewBackward" in str(out.grad_fn))
            self.assertTrue(torch.autograd.is_view_replay_enabled())
        self.assertFalse(torch.autograd.is_view_replay_enabled())

        # Test as a function
        torch.autograd._force_original_view_tracking(False)
        out = f(x)
        self.assertTrue("AsStridedBackward" in str(out.grad_fn))
        self.assertFalse(torch.autograd.is_view_replay_enabled())

        torch.autograd._force_original_view_tracking(True)
        out = f(x)
        self.assertTrue("ViewBackward" in str(out.grad_fn))
        self.assertTrue(torch.autograd.is_view_replay_enabled())

    def test_unsafe_set_version_counter(self):
        x = torch.ones(2, requires_grad=True).clone()
        x.add_(1)
        x.add_(2)
        self.assertEqual(2, x._version)
        with torch.autograd._unsafe_preserve_version_counter(x):
            x.mul_(2)
            x.mul_(3)
        # version counter doesn't change inside of the context manager
        self.assertEqual(2, x._version)

        torch._C._autograd._unsafe_set_version_counter(x, 0)
        self.assertEqual(0, x._version)
        with self.assertRaisesRegex(RuntimeError, "Cannot set"):
            torch._C._autograd._unsafe_set_version_counter(x, -1)

    def test_current_node(self):
        pr = []

        class MyMode(TorchDispatchMode):
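            # Intercepts every ATen op and records which autograd Node (if any)
            # the engine is currently executing when the op runs.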
            def __torch_dispatch__(self, func, types, args, kwargs=None):
                node = torch._C._current_autograd_node()
                # Don't use node.name() here as it is not consistent on Windows
                node_name = node.__class__.__name__ if node else "None"
                pr.append(f"Running {func} from within {node_name}")
                return func(*args, **(kwargs or {}))

        with MyMode():
            pr.append("FW")
            a = torch.rand(10, requires_grad=True)
            b = a.mul(2).div(3).sum()
            pr.append("BW")
            b.backward()
            pr.append("Done")

        self.assertExpectedInline(
            "\n".join(pr),
            """\
FW
Running aten.rand.default from within None
Running aten.mul.Tensor from within None
Running aten.div.Tensor from within None
Running aten.sum.default from within None
BW
Running aten.ones_like.default from within None
Running aten.expand.default from within SumBackward0
Running aten.div.Tensor from within DivBackward0
Running aten.mul.Tensor from within MulBackward0
Running aten.detach.default from within AccumulateGrad
Running aten.detach.default from within AccumulateGrad
Done""",
        )

    def test_profiler(self):
        x = torch.randn(10, 10)

        with profile(use_kineto=kineto_available()) as p:
            self.assertTrue(torch.autograd._profiler_enabled())
            y = x * 2 + 4

        self.assertFalse(torch.autograd._profiler_enabled())

        names = ["aten::mul", "aten::add"]
        found_indices = set()
        for evt in p.function_events:
            if evt.name in names:
                found_indices.add(names.index(evt.name))
        self.assertEqual(len(found_indices), len(names))

    def test_profiler_seq_nr(self):
        with profile(use_kineto=kineto_available()) as p:
            x = torch.randn(10, 10, requires_grad=True)
            y = torch.randn(10, 10, requires_grad=True)
            z = x + y
            s = z.sum(dim=None)
            s.backward()
        print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
        # expecting aten::add and aten::sum to have sequence numbers,
        # and expecting the corresponding backward nodes to have the same numbers
        # as the forward ops
        autograd_ops = {
            ("aten::add", "Add"): [],
            ("aten::sum", "Sum"): [],
        }
        accumulate_ops = []
        found_empty = False
        for e in p.function_events:
            for (fwd_name, bwd_name), ops in autograd_ops.items():
                if e.name == fwd_name or (bwd_name in e.name and "Backward" in e.name):
                    ops.append(e)

            if "AccumulateGrad" in e.name:
                accumulate_ops.append(e)

            # check that nested ops (e.g. empty) don't have
            # sequence number
            if e.name == "aten::empty":
                self.assertEqual(e.sequence_nr, -1)
                found_empty = True

        for idx, ((fwd_name, bwd_name), ops) in enumerate(autograd_ops.items()):
            self.assertEqual(len(ops), 3)
            self.assertEqual(ops[0].name, fwd_name)
            self.assertEqual(
                ops[1].name,
                f"autograd::engine::evaluate_function: {bwd_name}Backward{idx}",
            )
            self.assertEqual(ops[2].name, f"{bwd_name}Backward{idx}")
            self.assertGreaterEqual(ops[0].sequence_nr, 0)
            self.assertEqual(ops[1].sequence_nr, ops[0].sequence_nr)
            self.assertEqual(ops[2].sequence_nr, ops[0].sequence_nr)
            self.assertEqual(ops[0].fwd_thread, 0)
            self.assertEqual(ops[1].fwd_thread, ops[0].thread)
            self.assertEqual(ops[2].fwd_thread, ops[0].thread)
        self.assertTrue(found_empty)

    def test_profiler_unboxed_only(self):
        x = torch.rand(3, 4)

        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
            x.resize_([3, 2])

    def test_profiler_propagation(self):
        def foo(x):
            with record_function("in_foo") as rf:
                return x * 2

        x = torch.rand(3, 4)
        traced_foo = torch.jit.trace(foo, x)

        def bar(x):
            with record_function("in_bar") as rf:
                # we expect that the profiler will be able
                # to propagate across the fork
                fut = torch.jit._fork(traced_foo, x)
                y = torch.jit._wait(fut)
                # note: continuation (and rf's end) can
                # be executed in a different thread
                with record_function("in_bar_after_wait") as rf2:
                    y = y * 2
                return y

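        # Tracing bar bakes the fork/wait into TorchScript; the profiled run below
        # should still contain the ranges recorded inside the traced functions.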
        traced_bar = torch.jit.trace(bar, x)

        with profile(use_kineto=kineto_available()) as p:
            traced_bar(x)

        found_foo = False
        found_bar = False
        found_bar_after_wait = False
        for info in p.function_events:
            if info.name == "in_foo":
                self.assertFalse(found_foo)
                found_foo = True
            elif info.name == "in_bar":
                self.assertFalse(found_bar)
                found_bar = True
            elif info.name == "in_bar_after_wait":
                self.assertFalse(found_bar_after_wait)
                found_bar_after_wait = True
        self.assertTrue(found_foo)
        self.assertTrue(found_bar)
        self.assertTrue(found_bar_after_wait)

    def test_record_function_callbacks(self):
        x = torch.randn(10, 10)
        with profile(use_kineto=kineto_available()) as p:
            with record_function("foo"):
                y = x * 2 + 4

        function_events = p.function_events
        foo_event = next(event for event in function_events if "foo" in event.name)
        self.assertEqual(foo_event.count, 1)

    def test_record_function_legacy(self):
        # Test that the new _record_function ops work
        # Note: Remove once record_function uses these directly
        x = torch.randn(10, 10)
        with profile(use_kineto=kineto_available()) as p:
            handle = torch.ops.profiler._record_function_enter("bar", None)
            try:
                y = x * 2 + 4
            finally:
                torch.ops.profiler._record_function_exit(handle)

        function_events = p.function_events
        foo_event = next(event for event in function_events if "bar" in event.name)
        self.assertEqual(foo_event.count, 1)

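    # The test below builds FunctionEvents by hand (no real profiling run) so that
    # EventList._populate_cpu_children can be exercised against a small, fully
    # controlled set of time ranges.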
    def test_profiler_aggregation_fake(self):
        events = EventList()
        id = [0]

        def get_id():
            id[0] = id[0] + 1
            return id[0]

        # [[thread_id, [(start, end, id), ....]], ...]
        # Using a list instead of a dict so order is guaranteed for any Python
        # version
        threads = [
            [1, [(0, 1, get_id()), (1, 2, get_id())]],
            [0, [(0, 2, get_id()), (1, 2, get_id()), (1, 3, get_id())]],
        ]
        for thread, ranges in threads:
            for range in ranges:
                assert len(range) == 3
                events.append(
                    FunctionEvent(
                        id=range[2],
                        node_id=0,
                        name="",
                        thread=thread,
                        start_us=range[0],
                        end_us=range[1],
                    )
                )

        events._populate_cpu_children()

        # Note that [1, 3] pushes out [0, 2] first. Then we record [1, 2]
        # as a child of [1, 3]
        res = [[], [], [], [], [4]]

        def get_children_ids(event):
            return [child.id for child in event.cpu_children]

        assert [get_children_ids(event) for event in events] == res

    def test_profiler_aggregation_table(self):
        """
        Test that the profiling result is aggregated for `str(prof)`

        See: https://github.com/pytorch/pytorch/issues/37500
        """

        x = torch.randn(1024)
        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
            torch.einsum("i->", x)

        prof_str = str(prof)
        prof_table = prof.table()

        self.assertEqual(prof_table, prof_str)

    def test_profiler_function_event_avg(self):
        avg = FunctionEventAvg()
        avg.add(
            FunctionEvent(id=0, node_id=0, name="foo", thread=0, start_us=10, end_us=15)
        )
        avg.add(
            FunctionEvent(id=1, node_id=0, name="foo", thread=0, start_us=20, end_us=30)
        )
        avg.add(avg)
        self.assertEqual(avg.key, "foo")

        # aggregate stats
        self.assertEqual(avg.count, 4)
        self.assertEqual(avg.cpu_time_total, 30)
        self.assertEqual(avg.self_cpu_time_total, 30)
        self.assertEqual(avg.device_time_total, 0)

        # average stats
        self.assertEqual(avg.cpu_time, 7.5)
        self.assertEqual(avg.device_time_total, 0)

    def test_profiler_shapes(self):
        print()
        layer1 = torch.nn.Linear(20, 30)
        layer2 = torch.nn.Linear(30, 40)
        input = torch.randn(128, 20)
        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
            layer2(layer1(input))

        print(prof.function_events)

        linear_expected_shapes = [
            [[128, 20], [30, 20], [30]],
            [[128, 30], [40, 30], [40]],
        ]

        found_indices = set()
        for event in prof.function_events:
            if event.name == "aten::linear":
                self.assertTrue(event.input_shapes in linear_expected_shapes)
                found_indices.add(linear_expected_shapes.index(event.input_shapes))
        self.assertEqual(len(found_indices), len(linear_expected_shapes))

    def test_profiler_aggregation_lstm(self):
        print()
        rnn = torch.nn.LSTM(10, 20, 2)
        total_time_s = 0
        with profile(record_shapes=True, use_kineto=kineto_available()) as prof:
            for i in range(20):
                input = torch.randn(5, 3, 10)
                h = torch.randn(2, 3, 20)
                c = torch.randn(2, 3, 20)
                start = time.time()
                rnn(input, (h, c))
                end = time.time()
                total_time_s += end - start

        print(prof.table(sort_by="self_cpu_time_total", row_limit=10, header="TEST"))
        print(
            prof.key_averages(group_by_input_shape=True).table(
                sort_by="self_cpu_time_total", row_limit=10
            )
        )
        print(
            prof.table(
                sort_by="self_cpu_time_total",
                row_limit=10,
                max_src_column_width=300,
                header="TEST",
                top_level_events_only=True,
            )
        )
        print(
            prof.key_averages(group_by_input_shape=True).table(
                sort_by="self_cpu_time_total", row_limit=10, top_level_events_only=True
            )
        )

        total_time_us = (
            total_time_s * 1000.0 * 1000.0
        )  # make it us which is profiler default
        print("Total time based on python measurements: ", _format_time(total_time_us))
        print(
            f"CPU time measurement python side overhead: {(total_time_us / prof.self_cpu_time_total - 1.0) * 100.0:.2f}%"
        )

        if sys.platform != "win32":
            with tempfile.NamedTemporaryFile() as trace_file:
                prof.export_chrome_trace(trace_file.name)

    def test_record_function(self):
        x = torch.randn(10, 10)

        def forward(x):
            with record_function("outer"):
                y = x * 2 + 4
                with record_function("inner"):
                    y = y - 1
                y = y / 1

        forward(x)

        with profile(use_kineto=kineto_available()) as p:
            forward(x)

        events = p.function_events
        important_events = [
            "outer",
            "aten::mul",
            "aten::add",
            "inner",
            "aten::sub",
            "aten::div",
        ]
        idx = 0
        for info in events:
            if info.name == important_events[idx]:
                idx = idx + 1
            if idx == len(important_events):
                break
        self.assertEqual(idx, len(important_events))

        # We can also use record_function to decorate arbitrary functions
        @record_function("my_func")
        def f(x, y):
            return x + y

        with profile(use_kineto=kineto_available()) as p:
            f(1, 2)

        self.assertTrue("my_func" in str(p))

    def test_record_function_multithreaded(self):
        rf = record_function("outer")
        rf.__enter__()
        with record_function("inner"):
            # test that exiting the record function after starting another one
            # doesn't throw.
            rf.__exit__(None, None, None)

        with record_function("inner"):
            rf.__enter__()
        # test that exiting the record function after ending another one
        # doesn't throw.
        rf.__exit__(None, None, None)

    def test_dir(self):
        x = torch.randn(10, 10)
        keys = dir(x)
        self.assertIn("shape", keys)

        # real and imag are only implemented for complex tensors.
        y = torch.randn(10, 10, dtype=torch.cfloat)
        imag_key = "imag"
        self.assertRaises(RuntimeError, lambda: hasattr(x, imag_key))
        self.assertTrue(hasattr(y, imag_key))
        keys.remove(imag_key)

        for key in keys:
            self.assertTrue(hasattr(x, key))

    def test_inplace_on_view_saved_output(self):
        # Test an in-place operation on a view in which the in-place op saves
        # its output. Previously, this created a reference cycle.
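        # Rough sketch of the cycle being guarded against (added commentary, not from the
        # original test): an in-place relu on a view saves its output for the backward pass,
        # and that saved output's autograd metadata points back into the same graph,
        # roughly:
        #
        #   view -> ReluBackward -> saved output -> ... -> ReluBackward
        #
        # The IncrementOnDelete hook below only fires if that graph is actually freed.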
        dealloc = [0]

        class IncrementOnDelete:
            def __del__(self):
                dealloc[0] += 1

        def test():
            root = torch.randn(3, 3, requires_grad=True)
            copy = root.clone()
            copy.grad_fn.register_hook(IncrementOnDelete())
            view = copy.view(9)
            torch.nn.functional.relu(view, inplace=True)

        test()
        self.assertEqual(dealloc[0], 1)

    def test_inplace_on_view_leaf_errors(self):
        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
        x = torch.zeros(1, requires_grad=True)
        y = x.view_as(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "a view of a leaf Variable that "
            "requires grad is being used in "
            "an in-place operation.",
        ):
            y.add_(1)

    def test_inplace_on_view_backward(self):
        # Issue #10532: Make sure that this does not raise RuntimeError.
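        # Added note (illustrative, not from the original test): net(x) runs InstanceNorm2d
        # followed by an in-place ReLU, so the create_graph=True double backward below has to
        # differentiate through an in-place op; the test only asserts that it completes and
        # that x itself is left unchanged.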
        net = nn.Sequential(nn.InstanceNorm2d(2), nn.ReLU(True))

        x = torch.tensor([[[[1.0, 1.0]]]], requires_grad=True)
        (g,) = torch.autograd.grad(
            net(x).pow(2), [x], grad_outputs=x.new_ones(x.shape), create_graph=True
        )
        torch.autograd.grad(g.sum(), [x])
        self.assertEqual(x, torch.tensor([[[[1.0, 1.0]]]]))

        # https://discuss.pytorch.org/t/freeing-buffer-strange-behavior/31955/8
        inputs = torch.ones((1, 3, 256, 256), requires_grad=True)

        tmp1 = (inputs + 1).view_as(inputs)
        tmp2 = torch.nn.functional.threshold(tmp1, 0.0, 0.0, True)
        prob_interpolated = torch.sigmoid(tmp2)

        gradients = torch.autograd.grad(
            outputs=prob_interpolated,
            inputs=inputs,
            grad_outputs=torch.ones(prob_interpolated.size()),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradient_penalty = gradients.sum()
        gradient_penalty.backward()

        fn = gradient_penalty.grad_fn.next_functions[0][0].next_functions[1][0]
        self.assertEqual(fn.name(), "ThresholdBackwardBackward0")

    def test_inplace_on_view_weak_grad_fn(self):
        # Issue 23502: Test that b's grad_fn is preserved.
        a = torch.arange(10.0, requires_grad=True)

        b = a.narrow(0, 0, 2).clone().view(-1)
        b.relu_()

        c = b.clone()
        del b
        gc.collect()

        s = c.sum()
        s.backward()
        self.assertEqual(s, torch.tensor(1.0))

        # Issue #21875: Fail faster (when we try to modify the view vs. in backward())
        a = torch.rand(10, requires_grad=True).narrow(0, 0, 10)
        with self.assertRaises(RuntimeError):
            b = a.relu_()

    def test_out_variant_raises_when_inputs_require_grad(self):
        a = torch.randn(2, 2, requires_grad=True)
        b = torch.randn(2, 2, requires_grad=True)
        x = torch.zeros_like(a)

        # out=... functions don't support automatic differentiation currently
        self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x))

        # the inputs can require grad if we're in no_grad() mode
        with torch.no_grad():
            torch.mul(a, b, out=x)
            self.assertEqual(x, a * b)

        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        x = torch.zeros(2, 2, requires_grad=True)
        # we should throw an exception if the output requires grad
        self.assertRaisesRegex(RuntimeError, "out=", lambda: torch.mul(a, b, out=x))

    def test_anomaly_detect_nan(self):
        size = 10

        class MyFunc(Function):
            @staticmethod
            def forward(ctx, inp1, inp2, fail_0th):
                ctx.fail_0th = fail_0th
                return inp1.sum(0, keepdim=True)

            @staticmethod
            def backward(ctx, gO):
                gI = gO.clone().expand(size)
                gI[0] = 0
                gI[0] /= 0  # Generate a nan
                if ctx.fail_0th:
                    return gI, None, None
                else:
                    return None, gI, None

        inp = torch.rand(size, requires_grad=True)
        out = MyFunc.apply(inp, inp, True)
        out.backward()  # Should not fail

        inp = torch.rand(size, requires_grad=True)
        out = MyFunc.apply(inp, inp, True)
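        # Added note (rough sketch of the mechanism): under detect_anomaly() every backward
        # node's outputs are checked for NaN, so the hand-written backward above, which does
        # gI[0] /= 0, should make this second backward fail. Because `out` was produced
        # outside the detect_anomaly() block, the error is expected to warn that no forward
        # pass traceback information is available.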
        with self.assertRaisesRegex(
            RuntimeError,
            "Function 'MyFuncBackward' returned nan values in its 0th output.",
        ):
            with warnings.catch_warnings(record=True) as w:
                with detect_anomaly():
                    out.backward()
            self.assertIn("No forward pass information", str(w[0].message))

        inp = torch.rand(size, requires_grad=True)
        with self.assertRaisesRegex(
            RuntimeError,
            "Function 'MyFuncBackward' returned nan values in its 1th output.",
        ):
            with warnings.catch_warnings(record=True) as w:
                with detect_anomaly():
                    out = MyFunc.apply(inp, inp, False)
                    out.backward()
            self.assertIn("MyFunc.apply", str(w[0].message))

    def test_calculate_shape_util(self):
        out = torch.randn(10, 5, requires_grad=True)
        grad = torch.randn(5, 10, requires_grad=True)
        out_shape, grad_shape = _calculate_shape(out, grad, False)

        assert out_shape == torch.Size([10, 5])
        assert grad_shape == torch.Size([5, 10])

        out = torch.nested.as_nested_tensor(
            [
                torch.randn(10, 5, requires_grad=True),
                torch.randn(10, 5, requires_grad=True),
                torch.randn(10, 5, requires_grad=True),
            ]
        )
        grad = torch.nested.as_nested_tensor(
            [
                torch.randn(5, 10, requires_grad=True),
                torch.randn(5, 10, requires_grad=True),
            ]
        )
        out_shape, grad_shape = _calculate_shape(out, grad, False)

        assert torch.equal(out_shape, torch.tensor([[10, 5], [10, 5], [10, 5]]))
        assert torch.equal(grad_shape, torch.tensor([[5, 10], [5, 10]]))

    def test_nested_anomaly_detect_nan(self):
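        # Added summary (not part of the original test): this exercises anomaly detection one
        # level deeper, where the NaN is produced by a custom Function that only runs inside
        # the backward of another custom Function (i.e. during double backward).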
        size = 10

        class MyFunc(Function):
            @staticmethod
            def forward(ctx, inp1, fail_0th):
                ctx.fail_0th = fail_0th
                ctx.save_for_backward(inp1)
                return inp1.sum(0, keepdim=True)

            @staticmethod
            def backward(ctx, gO):
                (inp,) = ctx.saved_tensors
                fail_0th = ctx.fail_0th
                g = gO.clone().expand(size)
                gI = MyFunc2.apply(g * inp, g + inp, fail_0th)
                return gI, None

        class MyFunc2(Function):
            @staticmethod
            def forward(ctx, inp1, inp2, fail_0th):
                ctx.fail_0th = fail_0th
                return inp1 * 2.0 + inp2

            @staticmethod
            def backward(ctx, gO):
                fail_0th = ctx.fail_0th
                g1 = gO.clone()
                g2 = gO.clone()
                g1[0] = 0
                g2[0] = 0
                # generate a nan
                if fail_0th:
                    g1[0] /= 0
                else:
                    g2[0] /= 0
                return g1, g2, None

        inp = torch.rand(size, requires_grad=True)
        out = MyFunc.apply(inp, True)
        (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
        gsum = ginp.sum()
        gsum.backward()  # should not fail

        inp = torch.rand(size, requires_grad=True)
        out = MyFunc.apply(inp, True)
        (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
        gsum = ginp.sum()
        with warnings.catch_warnings(record=True) as w:
            with self.assertRaisesRegex(
                RuntimeError,
                "Function 'MyFunc2Backward' returned nan values in its 0th output.",
            ):
                with detect_anomaly():
                    gsum.backward()
        self.assertIn("No forward pass information", str(w[1].message))

        inp = torch.rand(size, requires_grad=True)
        with warnings.catch_warnings(record=True) as w:
            with self.assertRaisesRegex(
                RuntimeError,
                "Function 'MyFunc2Backward' returned nan values in its 1th output.",
            ):
                with detect_anomaly():
                    out = MyFunc.apply(inp, False)
                    (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)
                    gsum = ginp.sum()
                    gsum.backward()
        self.assertIn("MyFunc2.apply", str(w[1].message))
        self.assertIn("MyFunc.apply", str(w[2].message))

    def test_anomaly_grad_warnings(self):
        # PyTorch won't throw warnings if there is an error
        # but we'd want to at least see them in stderr

        class StdErrDiverter:
            def __enter__(self):
                self.stderr_orig = sys.stderr
                self.stderr_new = io.StringIO()
                sys.stderr = self.stderr_new
                return self

            def __exit__(self, *args):
                self.captured = self.stderr_new.getvalue()
                sys.stderr = self.stderr_orig

        # if the warnings don't throw, they will be handled as regular warnings
        with self.assertRaisesRegex(
            RuntimeError,
            "one of the variables needed for gradient computation has been "
            "modified by an inplace operation",
        ):
            with warnings.catch_warnings(record=True) as w:
                with detect_anomaly():
                    a = torch.randn(5, requires_grad=True)
                    d1 = a + 1
                    d2 = d1**2
                    d1 += 1
                    torch.autograd.grad(d2.sum(), a)

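        # Added note (hedged): two warnings are expected here, the generic "Anomaly
        # Detection has been enabled" warning and the "Error detected in PowBackward0"
        # warning that carries the forward traceback, which is what the assertions below
        # check.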
        self.assertEqual(len(w), 2)
        self.assertIn("Anomaly Detection has been enabled", str(w[0].message))
        self.assertIn("Error detected in PowBackward0", str(w[1].message))

        # if the warning throws, it will be printed to sys.stderr
        with self.assertRaisesRegex(
            RuntimeError,
            "one of the variables needed for gradient computation has been "
            "modified by an inplace operation",
        ):
            with warnings.catch_warnings(record=True) as w:
                with detect_anomaly():
                    warnings.simplefilter("error")
                    with StdErrDiverter() as s:
                        a = torch.randn(5, requires_grad=True)
                        d1 = a + 1
                        d2 = d1**2
                        d1 += 1
                        torch.autograd.grad(d2.sum(), a)

        self.assertEqual(len(w), 1)
        self.assertIn("Anomaly Detection has been enabled", str(w[0].message))
        self.assertIn("Error detected in PowBackward0", s.captured)

    def test_anomaly_assign_parent_cleanup(self):
        # Test that python objects created are properly cleaned up when assign_parent is called

        def get_ref():
            # we use torch.exp here but any function that will construct a new node in its
            # backward call in grad mode will work
            x = torch.randn(2, 2, requires_grad=True)
            t = x.exp()

            # ExpBackward calls mul, creating the MulBackward node when create_graph=True.
            # In anomaly mode, a PyObject referencing MulBackward's "parent" ExpBackward is added to
            # MulBackward's anomaly metadata dict, creating the following reference chain:
            #
            # grad -> MulBackward -> PyObject -> ExpBackward
            #
            with detect_anomaly():
                grad = torch.autograd.grad(t, x, torch.ones_like(t), create_graph=True)

            # We add a weak reference to a new Foo object, which we insert into ExpBackward's metadata dict
            #
            # (PyObject) -> ExpBackward -> dict -> *Foo*
            #       t ----^             WeakRef ---^
            #
            # We want to test that when grad goes out of scope at the end of this function that PyObject is destroyed
            # We can test this by seeing whether Foo is not kept alive once t is destroyed
            class Foo:
                pass

            my_obj = Foo()
            meta_dict = t.grad_fn.metadata
            meta_dict[0] = my_obj
            ref = weakref.ref(my_obj)
            return t, ref

        t, ref = get_ref()
        self.assertIsNotNone(ref())
        del t
        self.assertIsNone(ref())

    def test_nested_anomaly_printstack_cleanup(self):
        # Test if metadata dict PyObject is properly destroyed
        def get_ref():
            # This is similar to the construction in test_anomaly_assign_parent_cleanup:
            #
            # MyFuncBackward2 -> PyObject -> MyFuncBackward -> dict -> Foo
            #                          out ---^            WeakRef ---^
            #
            # We want to check that Foo is still properly destroyed even when MyFunc2Backward's
            # AnomalyMetadata calls printstack, which does some python object manipulation.
            #
            # You might be wondering why we still need test_anomaly_assign_parent_cleanup,
            # since if the PyObject were not destroyed here, wouldn't this test detect that as well?
            # The answer is that the custom function's PyObject (THPFunction) actually only holds
            # a weak reference to the c++ node!
            class MyFunc(Function):
                @staticmethod
                def forward(ctx, x):
                    ctx.save_for_backward(x)
                    return x

                @staticmethod
                def backward(ctx, gO):
                    (x,) = ctx.saved_tensors
                    return MyFunc2.apply(x)

            class MyFunc2(Function):
                @staticmethod
                def forward(ctx, x):
                    return x

                @staticmethod
                def backward(ctx, gO):
                    return gO + float("NaN")

            inp = torch.rand(1, requires_grad=True)
            out = MyFunc.apply(inp)
            (ginp,) = torch.autograd.grad(out, (inp,), create_graph=True)

            with warnings.catch_warnings(record=True) as w:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Function 'MyFunc2Backward' returned nan values in its 0th output.",
                ):
                    with detect_anomaly():
                        ginp.backward()

            class Foo:
                pass

            my_obj = Foo()
            meta_dict = out.grad_fn.metadata
            meta_dict[0] = my_obj
            ref = weakref.ref(my_obj)
            return out, ref

        t, ref = get_ref()
        self.assertIsNotNone(ref())
        del t
        self.assertIsNone(ref())

    def test_anomaly_mode_no_check_nan(self):
        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, gO):
                return torch.tensor(float("nan")).expand(10, 10)

        def run_fn(a):
            out = MyFunc.apply(a)
            return out.sum()

        with warnings.catch_warnings(record=True) as w:
            with torch.autograd.detect_anomaly(check_nan=False):
                inp = torch.rand(10, 10, requires_grad=True)
                out = run_fn(inp)
                out.backward(retain_graph=True)

                with torch.autograd.detect_anomaly(check_nan=True):
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "Function 'MyFuncBackward' returned nan values in its 0th output.",
                    ):
                        out.backward(retain_graph=True)

                out.backward()

    def test_no_grad_copy(self):
        # create autograd function that saves grad pointer as class static
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                MyFunc.static_grad_ptr = grad.data_ptr()
                return grad, grad

        class NonContGradFunc(Function):
            @staticmethod
            def forward(ctx, inp1):
                ctx.size = inp1.size()
                return torch.tensor([1.0])

            @staticmethod
            def backward(ctx, grad):
                return torch.ones(1).expand(ctx.size)

        a = torch.randn(5, 6, requires_grad=True)
        b = torch.randn(5, 6, requires_grad=True)
        # non-contiguous grad should be copied
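        # Added explanation (rough sketch of the optimization under test): when the incoming
        # gradient is contiguous and autograd does not need it elsewhere, it can be handed to
        # .grad without a copy, so the .grad data_ptr can match the pointer captured inside
        # backward. NonContGradFunc feeds an expanded (non-contiguous) gradient into MyFunc's
        # backward, which forces a copy, so neither a.grad nor b.grad should alias the
        # captured pointer in the first case below.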
        NonContGradFunc.apply(MyFunc.apply(a, b)).backward()
        self.assertFalse(a.grad.data_ptr() == MyFunc.static_grad_ptr)
        self.assertFalse(b.grad.data_ptr() == MyFunc.static_grad_ptr)
        # test case that should trigger no copy for one of a,b
        a.grad = b.grad = None
        MyFunc.apply(a, b)[1][0].backward()
        p_g = MyFunc.static_grad_ptr
        p_a = a.grad.data_ptr()
        p_b = b.grad.data_ptr()
        # check that a, b use different grad buffers
        self.assertFalse(p_a == p_b)
        # check that one of them is using the computed buffer
        self.assertTrue(p_a == p_g or p_b == p_g)

    def test_no_grad_copy_sparse(self):
        # create autograd function that saves grad pointer as class static
        class MyFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                MyFunc.static_grad_ptr = grad._values().data_ptr()
                return grad, grad

        class NonContGradFunc(Function):
            static_grad_ptr = None

            @staticmethod
            def forward(ctx, inp1, inp2):
                return inp1 + inp2

            @staticmethod
            def backward(ctx, grad):
                # Create a sparse tensor with non-contiguous indices and values
                # and return as grad.
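                # Added note (illustrative): expand() below produces views whose strides
                # contain 0, e.g. torch.rand(1, 3).expand(8, 3) reuses one row 8 times, so
                # the resulting sparse gradient has non-contiguous indices/values and the
                # accumulation path is expected to clone it rather than reuse the buffer.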
                v = torch.rand(1, 3)
                i = torch.ones(1, 1, dtype=torch.long)
                nv = v.expand(8, 3)
                ni = i.expand(1, 8)
                ngrad = torch.sparse_coo_tensor(ni, nv, (10, 3), dtype=torch.float32)
                NonContGradFunc.static_grad_ptr = ngrad._values().data_ptr()
                return ngrad, ngrad

        a = torch.randn(10, 3, requires_grad=True)
        b = torch.randn(10, 3, requires_grad=True)
        input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
        offsets = torch.tensor([0, 4])
        import torch.nn.functional as F

        # test case that should trigger no copy for one of a,b
        emb_matrix = MyFunc.apply(a, b)
        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
        loss.backward(retain_graph=True)
        p_g = MyFunc.static_grad_ptr
        p_a = a.grad._values().data_ptr()
        p_b = b.grad._values().data_ptr()
        # check that a, b use different grad buffers
        self.assertFalse(p_a == p_b)
        # check that one of them is using the computed buffer
        self.assertTrue(p_a == p_g or p_b == p_g)

        # Run backwards multiple times to ensure accumulation works.
        for i in range(10):
            loss.backward(retain_graph=True)

        # non-contiguous indices and values, so we should trigger a copy.
        a.grad = b.grad = None
        emb_matrix = NonContGradFunc.apply(a, b)
        loss = F.embedding_bag(emb_matrix, input, offsets, sparse=True).sum()
        loss.backward(retain_graph=True)
        p_g = NonContGradFunc.static_grad_ptr
        p_a = a.grad._values().data_ptr()
        p_b = b.grad._values().data_ptr()
        # check that a, b use different grad buffers
        self.assertFalse(p_a == p_b)
        # Verify we cloned both grads.
        self.assertFalse(p_a == p_g)
        self.assertFalse(p_b == p_g)

        # Run backwards multiple times to ensure accumulation works.
        for i in range(10):
            loss.backward(retain_graph=True)

    def test_gradcheck_single_input(self):
        def check(fast_mode):
            def f(inp):
                return inp.mul(5)

            gradcheck(
                f,
                torch.rand(10, dtype=torch.float64, requires_grad=True),
                fast_mode=fast_mode,
            )
            gradgradcheck(
                f,
                torch.rand(10, dtype=torch.float64, requires_grad=True),
                fast_mode=fast_mode,
            )

        check(fast_mode=True)
        check(fast_mode=False)

    @parametrize(
        "layout",
        (
            torch.sparse_coo,
            torch.sparse_csr,
            torch.sparse_csc,
            torch.sparse_bsr,
            torch.sparse_bsc,
        ),
    )
    def test_gradcheck_input(self, layout):
        if layout in {torch.sparse_bsr, torch.sparse_bsc}:
            blocksize = (2, 2)
            size = (4, 8)
        else:
            blocksize = None
            size = (2, 2)

        def check(fast_mode, masked):
            def fn(sparse):
                return torch.sum(sparse)

            gradcheck(
                fn,
                torch.rand(size, dtype=torch.double)
                .to_sparse(layout=layout, blocksize=blocksize)
                .requires_grad_(),
                masked=masked,
                check_batched_grad=False,
                fast_mode=fast_mode,
            )

        for fast_mode, masked in product(*[(True, False)] * 2):
            check(fast_mode=fast_mode, masked=masked)

    def test_gradcheck_nondeterministic(self):
        class NonDetFunc(Function):
            @staticmethod
            def forward(ctx, x, jitter=0.0):
                ctx._jitter = jitter
                return x

            @staticmethod
            def backward(ctx, grad_out):
                return (
                    NonDetFunc.apply(grad_out, ctx._jitter)
                    * (1 + torch.rand_like(grad_out) * ctx._jitter),
                    None,
                )

        def check(fast_mode):
            inp = torch.randn(5, 5, dtype=torch.double, requires_grad=True)
            gradcheck(
                lambda x: NonDetFunc.apply(x, 0.0),
                inp,
                check_batched_grad=False,
                fast_mode=fast_mode,
            )
            with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"):
                gradcheck(
                    lambda x: NonDetFunc.apply(x, 1e-6),
                    inp,
                    check_batched_grad=False,
                    fast_mode=fast_mode,
                )
            with self.assertRaisesRegex(RuntimeError, "Backward is not reentrant"):
                gradgradcheck(
                    lambda x: NonDetFunc.apply(x, 1e-12),
                    inp,
                    check_batched_grad=False,
                    fast_mode=fast_mode,
                )
            gradcheck(
                lambda x: NonDetFunc.apply(x, 0.0),
                inp,
                nondet_tol=1e-5,
                check_batched_grad=False,
                fast_mode=fast_mode,
            )
            gradcheck(
                lambda x: NonDetFunc.apply(x, 1e-6),
                inp,
                nondet_tol=1e-5,
                check_batched_grad=False,
                fast_mode=fast_mode,
            )
            gradgradcheck(
                lambda x: NonDetFunc.apply(x, 1e-12),
                inp,
                nondet_tol=1e-5,
                check_batched_grad=False,
                fast_mode=fast_mode,
            )

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_validates_inputs(self):
        def check(fast_mode):
            x = torch.rand(10, requires_grad=True).to_sparse()
            self.assertTrue(
                gradcheck(
                    lambda x: x.to_dense(),
                    (x,),
                    check_batched_grad=False,
                    atol=1e-1,
                    fast_mode=fast_mode,
                    masked=True,
                )
            )
            self.assertFalse(
                gradcheck(
                    lambda x: x.to_dense(),
                    (x,),
                    masked=False,
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )
            self.assertTrue(
                gradcheck(
                    lambda x: x.to_dense(masked_grad=False),
                    (x,),
                    masked=False,
                    atol=1e-1,
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )

            # when none of the inputs require grad (always raises even if raise_exception=False)
            x = torch.rand(10, requires_grad=False)
            with self.assertRaisesRegex(
                ValueError, "at least one input tensor to require gradient"
Worker ): 5696*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode) 5697*da0073e9SAndroid Build Coastguard Worker 5698*da0073e9SAndroid Build Coastguard Worker # (warning) when inputs are not double precision 5699*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, dtype=torch.float32, requires_grad=True) 5700*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 5701*da0073e9SAndroid Build Coastguard Worker UserWarning, "Input #0 requires gradient and is not a double precision" 5702*da0073e9SAndroid Build Coastguard Worker ): 5703*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 5704*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x, (x,), atol=1e-1, fast_mode=fast_mode) 5705*da0073e9SAndroid Build Coastguard Worker ) 5706*da0073e9SAndroid Build Coastguard Worker 5707*da0073e9SAndroid Build Coastguard Worker # when layout is not mkldnn(aka has strides) and input has a dimension with stride 0. (always raises 5708*da0073e9SAndroid Build Coastguard Worker # even if raise_exception=False) 5709*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1, dtype=torch.float64, requires_grad=True) 5710*da0073e9SAndroid Build Coastguard Worker x = x.expand((2, 2)) 5711*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5712*da0073e9SAndroid Build Coastguard Worker RuntimeError, "The 0th input has a dimension with stride 0" 5713*da0073e9SAndroid Build Coastguard Worker ): 5714*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x, (x,), raise_exception=False, fast_mode=fast_mode) 5715*da0073e9SAndroid Build Coastguard Worker 5716*da0073e9SAndroid Build Coastguard Worker check(fast_mode=True) 5717*da0073e9SAndroid Build Coastguard Worker check(fast_mode=False) 5718*da0073e9SAndroid Build Coastguard Worker 5719*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 5720*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 5721*da0073e9SAndroid Build Coastguard Worker ) 5722*da0073e9SAndroid Build Coastguard Worker def test_gradcheck_validates_input_mkldnn(self): 5723*da0073e9SAndroid Build Coastguard Worker # when mkldnn inputs, forward mode testing is not allowed 5724*da0073e9SAndroid Build Coastguard Worker # Update tolerances below to make sure the gradient match even in single precision floats 5725*da0073e9SAndroid Build Coastguard Worker # Use the warning assert to hide the float32 warning 5726*da0073e9SAndroid Build Coastguard Worker x = torch.ones(1).to_mkldnn().requires_grad_() 5727*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 5728*da0073e9SAndroid Build Coastguard Worker UserWarning, "Input #0 requires gradient and is not a double precision" 5729*da0073e9SAndroid Build Coastguard Worker ): 5730*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5731*da0073e9SAndroid Build Coastguard Worker ValueError, "MKLDNN inputs are not support for forward AD gradcheck." 
5732*da0073e9SAndroid Build Coastguard Worker ): 5733*da0073e9SAndroid Build Coastguard Worker gradcheck( 5734*da0073e9SAndroid Build Coastguard Worker lambda x: x.to_dense(), 5735*da0073e9SAndroid Build Coastguard Worker (x,), 5736*da0073e9SAndroid Build Coastguard Worker raise_exception=False, 5737*da0073e9SAndroid Build Coastguard Worker fast_mode=False, 5738*da0073e9SAndroid Build Coastguard Worker check_forward_ad=True, 5739*da0073e9SAndroid Build Coastguard Worker atol=1e-1, 5740*da0073e9SAndroid Build Coastguard Worker rtol=1e-1, 5741*da0073e9SAndroid Build Coastguard Worker ) 5742*da0073e9SAndroid Build Coastguard Worker 5743*da0073e9SAndroid Build Coastguard Worker with self.assertWarnsRegex( 5744*da0073e9SAndroid Build Coastguard Worker UserWarning, "Input #0 requires gradient and is not a double precision" 5745*da0073e9SAndroid Build Coastguard Worker ): 5746*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5747*da0073e9SAndroid Build Coastguard Worker ValueError, "MKLDNN inputs are not support for forward AD gradcheck." 5748*da0073e9SAndroid Build Coastguard Worker ): 5749*da0073e9SAndroid Build Coastguard Worker gradcheck( 5750*da0073e9SAndroid Build Coastguard Worker lambda x: x.to_dense(), 5751*da0073e9SAndroid Build Coastguard Worker (x,), 5752*da0073e9SAndroid Build Coastguard Worker raise_exception=False, 5753*da0073e9SAndroid Build Coastguard Worker fast_mode=True, 5754*da0073e9SAndroid Build Coastguard Worker check_forward_ad=True, 5755*da0073e9SAndroid Build Coastguard Worker atol=1e-1, 5756*da0073e9SAndroid Build Coastguard Worker rtol=1e-1, 5757*da0073e9SAndroid Build Coastguard Worker ) 5758*da0073e9SAndroid Build Coastguard Worker 5759*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf( 5760*da0073e9SAndroid Build Coastguard Worker not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled" 5761*da0073e9SAndroid Build Coastguard Worker ) 5762*da0073e9SAndroid Build Coastguard Worker def test_gradcheck_test_outputs(self): 5763*da0073e9SAndroid Build Coastguard Worker def check(fast_mode): 5764*da0073e9SAndroid Build Coastguard Worker # when sparse outputs (always raise even if raise_exception=False) 5765*da0073e9SAndroid Build Coastguard Worker x = torch.rand(10, requires_grad=True).to_sparse() 5766*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5767*da0073e9SAndroid Build Coastguard Worker ValueError, "Sparse output is not supported at gradcheck yet" 5768*da0073e9SAndroid Build Coastguard Worker ): 5769*da0073e9SAndroid Build Coastguard Worker gradcheck( 5770*da0073e9SAndroid Build Coastguard Worker lambda x: x, 5771*da0073e9SAndroid Build Coastguard Worker (x,), 5772*da0073e9SAndroid Build Coastguard Worker masked=True, 5773*da0073e9SAndroid Build Coastguard Worker check_batched_grad=False, 5774*da0073e9SAndroid Build Coastguard Worker raise_exception=False, 5775*da0073e9SAndroid Build Coastguard Worker fast_mode=fast_mode, 5776*da0073e9SAndroid Build Coastguard Worker ) 5777*da0073e9SAndroid Build Coastguard Worker 5778*da0073e9SAndroid Build Coastguard Worker # when mkldnn outputs (always raise even if raise_exception=False) 5779*da0073e9SAndroid Build Coastguard Worker root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True) 5780*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 5781*da0073e9SAndroid Build Coastguard Worker ValueError, "MKLDNN output is not supported at gradcheck yet" 5782*da0073e9SAndroid Build Coastguard Worker ): 5783*da0073e9SAndroid Build Coastguard 
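
    # Illustrative sketch (not part of the original test suite): the validation tests above
    # exercise gradcheck's input checks. A typical successful call, shown here with
    # torch.sin as a hypothetical op under test, uses double-precision inputs with
    # requires_grad=True so that finite differences are accurate enough to compare
    # against the analytical gradients; raise_exception=False turns failures into a
    # False return value instead of a raised error.
    def _sketch_basic_gradcheck_usage(self):
        x = torch.randn(3, dtype=torch.double, requires_grad=True)
        # Passes: sin has an exact analytical derivative (cos) for autograd to use.
        self.assertTrue(gradcheck(torch.sin, (x,)))
        # Same check, but a failure would be reported as a return value of False.
        self.assertTrue(gradcheck(torch.sin, (x,), raise_exception=False))
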
    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_gradcheck_test_outputs(self):
        def check(fast_mode):
            # when sparse outputs (always raise even if raise_exception=False)
            x = torch.rand(10, requires_grad=True).to_sparse()
            with self.assertRaisesRegex(
                ValueError, "Sparse output is not supported at gradcheck yet"
            ):
                gradcheck(
                    lambda x: x,
                    (x,),
                    masked=True,
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )

            # when mkldnn outputs (always raise even if raise_exception=False)
            root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)
            with self.assertRaisesRegex(
                ValueError, "MKLDNN output is not supported at gradcheck yet"
            ):
                gradcheck(
                    lambda x: x.to_mkldnn(),
                    (root,),
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_check_no_differentiable_outputs(self):
        def check(fast_mode):
            # When none of the outputs are differentiable, but the numerical gradient is not zero
            x = torch.ones((1,), requires_grad=True)
            with self.assertRaisesRegex(
                RuntimeError, "Numerical gradient for function expected to be zero"
            ):
                gradcheck(lambda x: torch.tensor([x]), x)
            self.assertFalse(
                gradcheck(
                    lambda x: torch.tensor([x]),
                    x,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )

            # succeed when there are no outputs at all
            self.assertTrue(gradcheck(lambda x: (), (x,), fast_mode=fast_mode))

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_check_batched_grad(self):
        def check(fast_mode):
            x = torch.rand(10, dtype=torch.double, requires_grad=True).to_sparse()
            # runtime error while computing batched grad (prints a big error)
            with self.assertRaisesRegex(
                RuntimeError,
                "gradcheck or gradgradcheck failed while testing batched gradient",
            ):
                gradcheck(
                    lambda x: x.to_dense(),
                    (x,),
                    masked=True,
                    check_batched_grad=True,
                    fast_mode=fast_mode,
                )
            self.assertFalse(
                gradcheck(
                    lambda x: x.to_dense(),
                    (x,),
                    masked=True,
                    check_batched_grad=True,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_backward_mul_by_grad_output(self):
        # when grad_input is sparse and has incorrect sparse_dim/dense_dim
        def check(fast_mode):
            def fn(x):
                def hook(grad):
                    if grad is not None:
                        return grad.to_dense().to_sparse(1)
                    return grad

                y = x.clone()
                y.register_hook(hook)
                return y.to_dense()

            x = torch.ones((2, 2), dtype=torch.double, requires_grad=True).to_sparse()
            with self.assertRaisesRegex(
                RuntimeError, "grad is sparse tensor, but has incorrect sparse_dim"
            ):
                gradcheck(
                    fn,
                    (x,),
                    atol=1e-1,
                    masked=True,
                    check_batched_grad=False,
                    fast_mode=fast_mode,
                )
            self.assertFalse(
                gradcheck(
                    fn,
                    (x,),
                    atol=1e-1,
                    masked=True,
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )

            # when backward is not multiplied by grad_output (non-sparse case)
            def fn2(x):
                y = x.clone()
                y.register_hook(lambda x: x + 1e-2)
                return y

            x = torch.ones(1, dtype=torch.double, requires_grad=True)
            with self.assertRaisesRegex(
                RuntimeError, "backward not multiplied by grad_output"
            ):
                gradcheck(fn2, (x,), atol=1e-1, fast_mode=fast_mode)
            self.assertFalse(
                gradcheck(
                    fn2, (x,), atol=1e-1, raise_exception=False, fast_mode=fast_mode
                )
            )

            # when backward is not multiplied by grad_output (sparse case)
            def fn3(x):
                y = x.clone().to_dense()
                y.register_hook(lambda x: x + 1e-2)
                return y

            x = torch.ones(1, dtype=torch.double, requires_grad=True).to_sparse()
            with self.assertRaisesRegex(
                RuntimeError, "backward not multiplied by grad_output"
            ):
                gradcheck(
                    fn3,
                    (x,),
                    atol=1e-1,
                    masked=True,
                    check_batched_grad=False,
                    fast_mode=fast_mode,
                )
            self.assertFalse(
                gradcheck(
                    fn3,
                    (x,),
                    atol=1e-1,
                    masked=True,
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )

            # when the layout of grad_input is not the same as the input
            class Test(Function):
                @staticmethod
                def forward(ctx, x):
                    return x

                @staticmethod
                def backward(ctx, x):
                    return x.to_sparse()

            x = torch.ones(1, dtype=torch.double, requires_grad=True)
            with self.assertRaisesRegex(RuntimeError, "grad is incorrect layout"):
                gradcheck(
                    Test.apply, (x,), check_batched_grad=False, fast_mode=fast_mode
                )
            self.assertFalse(
                gradcheck(
                    Test.apply,
                    (x,),
                    check_batched_grad=False,
                    raise_exception=False,
                    fast_mode=fast_mode,
                )
            )

        check(fast_mode=True)
        check(fast_mode=False)
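
    # Illustrative sketch (not part of the original test suite): the failing cases above all
    # forget to propagate grad_output through the chain rule. A well-behaved custom
    # Function multiplies its local derivative by grad_output, so gradcheck passes.
    # The doubling Function below is a hypothetical example.
    def _sketch_custom_function_passes_gradcheck(self):
        class Double(Function):
            @staticmethod
            def forward(ctx, x):
                return 2 * x

            @staticmethod
            def backward(ctx, grad_output):
                # d(2x)/dx = 2; the chain rule requires scaling by grad_output.
                return 2 * grad_output

        x = torch.randn(3, dtype=torch.double, requires_grad=True)
        self.assertTrue(gradcheck(Double.apply, (x,)))
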
    def test_gradcheck_undefined_grad(self):
        def check(fast_mode):
            # when a runtime error is encountered while running backward
            def fn(x):
                def hook(x):
                    if x is None:
                        raise RuntimeError("x is undefined")

                y = x.clone()
                y.register_hook(hook)
                return y

            x = torch.ones(1, dtype=torch.double, requires_grad=True)
            with self.assertWarnsRegex(
                UserWarning,
                "Backwards compatibility: New undefined gradient support checking feature",
            ):
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Expected backward function to handle undefined output grads",
                ):
                    gradcheck(fn, (x,), fast_mode=fast_mode)
                self.assertFalse(
                    gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
                )

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_jacobian_mismatch(self):
        def check(fast_mode):
            def fn(x):  # R -> R, C -> C
                y = x.clone()
                y.register_hook(lambda x: x + 1e-2)
                return y

            x = torch.ones(2, 2, requires_grad=True)
            with self.assertRaisesRegex(
                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
            ):
                gradcheck(fn, (x,), fast_mode=fast_mode)
            self.assertFalse(
                gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
            )

            x_c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
            with self.assertRaisesRegex(
                RuntimeError,
                "While considering the imaginary part of complex outputs only",
            ):
                gradcheck(fn, (x_c,), fast_mode=False)
            self.assertFalse(
                gradcheck(fn, (x_c,), raise_exception=False, fast_mode=False)
            )

            def fn2(x):  # R -> C
                y = torch.complex(x, x)
                y.register_hook(lambda x: x + 1e-2)
                return y

            x = torch.ones(2, 2, requires_grad=True)
            with self.assertRaisesRegex(
                RuntimeError,
                "While considering the imaginary part of complex outputs only",
            ):
                gradcheck(fn2, (x,), fast_mode=False)
            self.assertFalse(
                gradcheck(fn2, (x,), raise_exception=False, fast_mode=False)
            )

            def fn3(x):  # C -> R
                y = torch.real(x)
                y.register_hook(lambda x: x + 1e-2)
                return y

            with self.assertRaisesRegex(
                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
            ):
                gradcheck(fn3, (x_c,), fast_mode=False)
            self.assertFalse(
                gradcheck(fn3, (x_c,), raise_exception=False, fast_mode=False)
            )

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_dense_and_sparse_inputs(self):
        def check(fast_mode):
            def fn(x, y):
                return x * y.coalesce().to_dense()

            a = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
            b = torch.rand(2, 2, dtype=torch.double).to_sparse().requires_grad_(True)
            self.assertTrue(
                gradcheck(
                    fn,
                    (a, b),
                    masked=True,
                    check_batched_grad=False,
                    fast_mode=fast_mode,
                )
            )

        check(fast_mode=True)
        check(fast_mode=False)

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_gradcheck_multiple_mkldnn_inputs(self):
        def check(fast_mode):
            def fn(x, y):
                return x + y.to_dense()

            a = torch.rand(10, requires_grad=True)
            b = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
            self.assertTrue(
                gradcheck(
                    fn, (a, b), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode
                )
            )

            def fn2(x, y):
                return x.to_dense() + y.to_dense()

            c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
            self.assertTrue(
                gradcheck(
                    fn, (a, c), atol=1e-1, check_batched_grad=False, fast_mode=fast_mode
                )
            )

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_output_shape_or_dtype_depend_on_values(self):
        def check(fast_mode):
            def fn(x):
                if torch.all(x >= 1):
                    return torch.cat([x, x])
                else:
                    return x

            a = torch.ones(1, dtype=torch.double, requires_grad=True)
            with self.assertRaisesRegex(
                AssertionError,
                "return outputs with the same shape when inputs are perturbed",
            ):
                self.assertTrue(gradcheck(fn, (a,), fast_mode=fast_mode))

            def fn2(x):
                if torch.all(x >= 1):
                    return x.to(torch.float32)
                else:
                    return x

            with self.assertRaisesRegex(
                AssertionError,
                "return outputs with the same dtype when inputs are perturbed",
            ):
                self.assertTrue(gradcheck(fn2, (a,), fast_mode=fast_mode))

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_complex_non_complex_outputs(self):
        def fn(x, y):
            z = torch.complex(x, y)
            return z, x + 1

        a = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
        b = torch.ones(2, 2, requires_grad=True, dtype=torch.float64)
        self.assertTrue(gradcheck(fn, (a, b)))

        def fn2(z):
            return z, torch.real(z)

        c = torch.ones(2, 2, requires_grad=True, dtype=torch.complex128)
        self.assertTrue(gradcheck(fn2, (c)))

    def test_gradcheck_get_numerical_jacobian(self):
        # get_numerical_jacobian is deprecated and no longer used internally by gradcheck
        from torch.autograd.gradcheck import get_numerical_jacobian

        def fn(inputs):
            # get_numerical_jacobian requires fn to take inputs as a tuple
            # and returns the jacobian wrt the first output
            x = inputs[0]
            y = inputs[1]
            return 2 * x + y, x + 2 * y

        a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
        b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)

        with self.assertWarnsRegex(
            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
        ):
            jacobian = get_numerical_jacobian(fn, (a, b), target=a, eps=1e-6)
        self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))

        with self.assertWarnsRegex(
            FutureWarning, "`get_numerical_jacobian` was part of PyTorch's private API"
        ):
            jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6)
        self.assertEqual(jacobian[0], 2 * torch.eye(4, dtype=torch.double))
        self.assertEqual(jacobian[1], 1 * torch.eye(4, dtype=torch.double))

        with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
            jacobian = get_numerical_jacobian(fn, (a, b), eps=1e-6, grad_out=2.0)

    def test_gradcheck_get_analytical_jacobian(self):
        from torch.autograd.gradcheck import get_analytical_jacobian

        def fn(x, y):
            return 2 * x + y, x + 2 * y

        a = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)
        b = torch.rand(2, 2, requires_grad=True, dtype=torch.float64)

        outputs = fn(a, b)
        with self.assertWarnsRegex(
            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
        ):
            (
                jacobians,
                reentrant,
                correct_grad_sizes,
                correct_grad_types,
            ) = get_analytical_jacobian((a, b), outputs[0])
        self.assertEqual(jacobians[0], 2 * torch.eye(4, dtype=torch.double))
        self.assertEqual(jacobians[1], 1 * torch.eye(4, dtype=torch.double))
        self.assertTrue(reentrant)

        class NonDetFunc(Function):
            @staticmethod
            def forward(ctx, x, jitter=0.0):
                ctx._jitter = jitter
                return x

            @staticmethod
            def backward(ctx, grad_out):
                return (
                    NonDetFunc.apply(grad_out, ctx._jitter)
                    * (1 + torch.rand_like(grad_out) * ctx._jitter),
                    None,
                )

        outputs = NonDetFunc.apply(a, 1e-6)
        with self.assertWarnsRegex(
            FutureWarning, "`get_analytical_jacobian` was part of PyTorch's private API"
        ):
            (
                jacobians,
                reentrant,
                correct_grad_sizes,
                correct_grad_types,
            ) = get_analytical_jacobian((a,), outputs)
        self.assertFalse(reentrant)

        with self.assertRaisesRegex(ValueError, "Expected grad_out to be 1.0"):
            jacobians, _, _, _ = get_analytical_jacobian((a,), outputs, grad_out=2.0)
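
    # Illustrative sketch (not part of the original test suite): the two tests above pin
    # down the deprecation warnings for the private jacobian helpers. A supported way to
    # materialize a Jacobian is torch.autograd.functional.jacobian; the elementwise-linear
    # function below is a hypothetical example chosen so the expected Jacobians are easy
    # to state (multiples of the identity, reshaped to output shape x input shape).
    def _sketch_functional_jacobian(self):
        from torch.autograd.functional import jacobian

        def fn(x, y):
            return 2 * x + y

        a = torch.rand(2, 2, dtype=torch.float64)
        b = torch.rand(2, 2, dtype=torch.float64)
        jac_a, jac_b = jacobian(fn, (a, b))
        self.assertEqual(
            jac_a, 2 * torch.eye(4, dtype=torch.double).reshape(2, 2, 2, 2)
        )
        self.assertEqual(jac_b, torch.eye(4, dtype=torch.double).reshape(2, 2, 2, 2))
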
    def test_gradcheck_custom_error(self):
        from torch.autograd.gradcheck import GradcheckError

        def check(fast_mode):
            def fn(x):
                y = x.clone()
                y.register_hook(lambda x: x + 1e-2)
                return y

            x = torch.ones(2, 2, requires_grad=True)
            with self.assertRaisesRegex(
                GradcheckError, "Jacobian mismatch for output 0 with respect to input 0"
            ):
                gradcheck(fn, (x,), fast_mode=fast_mode)
            with self.assertRaisesRegex(
                RuntimeError, "Jacobian mismatch for output 0 with respect to input 0"
            ):
                gradcheck(fn, (x,), fast_mode=fast_mode)
            self.assertFalse(
                gradcheck(fn, (x,), raise_exception=False, fast_mode=fast_mode)
            )

            def fn2(x):
                raise RuntimeError("Not a GradcheckError!")

            # Checks that when raise_exception=False, non-GradcheckErrors are not caught by gradcheck
            with self.assertRaisesRegex(RuntimeError, "Not a GradcheckError!"):
                gradcheck(fn2, (x,), fast_mode=fast_mode, raise_exception=False)

        check(fast_mode=True)
        check(fast_mode=False)

    def test_gradcheck_forward_ad(self):
        def fn(x, y):
            return x + y, y

        def bad_fn(x, y):
            # Hacky way to check if we're currently inside a forward ad level
            is_running_forward_ad = fwAD._current_level >= 0

            if is_running_forward_ad:
                y_p, y_d = fwAD.unpack_dual(y)
                y = fwAD.make_dual(y_p, y_d * 1.1)

            return x + y, y

        err_msg = "Jacobian computed with forward mode mismatch for output 0 with respect to input 1"

        for fast_mode in [True, False]:
            # Test for all inputs and outputs being real
            x = torch.rand(2, dtype=torch.double, requires_grad=True)
            y = torch.rand(2, dtype=torch.double, requires_grad=True)

            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)

            def basic_mul(x):
                return torch.view_as_real(torch.resolve_conj(x * 1j))

            gradcheck(basic_mul, x, check_forward_ad=True, fast_mode=fast_mode)

            # Test for one input and one output being complex
            x = torch.rand(2, dtype=torch.cdouble, requires_grad=True)

            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)

            # Test for all inputs and outputs being complex
            y = torch.rand(2, dtype=torch.cdouble, requires_grad=True)

            gradcheck(fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
            with self.assertRaisesRegex(RuntimeError, err_msg):
                gradcheck(bad_fn, (x, y), check_forward_ad=True, fast_mode=fast_mode)
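
    # Illustrative sketch (not part of the original test suite): check_forward_ad above
    # drives the same dual-number machinery that can be used directly. A minimal example,
    # using torch.sin as a hypothetical op under test: the tangent carried by the dual
    # input is mapped through the JVP, i.e. cos(x) * tangent.
    def _sketch_forward_ad_dual_numbers(self):
        x = torch.rand(3, dtype=torch.double)
        tangent = torch.rand(3, dtype=torch.double)
        with fwAD.dual_level():
            dual_x = fwAD.make_dual(x, tangent)
            dual_out = torch.sin(dual_x)
            primal_out, jvp_out = fwAD.unpack_dual(dual_out)
        self.assertEqual(primal_out, torch.sin(x))
        self.assertEqual(jvp_out, torch.cos(x) * tangent)
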
    def test_gradcheck_forward_ad_runs_with_no_requires_grad(self):
        # Currently requires_grad is used as an easy way for gradcheck to know
        # which inputs of the function are meant to be differentiable
        # This test checks that the inputs passed to the function do not have
        # requires_grad=True even though they may have requires_grad=True when passed
        # to gradcheck
        class UserFn(Function):
            @staticmethod
            def forward(ctx, x, y):
                if fwAD._current_level >= 0:
                    self.assertFalse(x.requires_grad)
                    self.assertFalse(y.requires_grad)
                return x.clone(), y.clone()

            @staticmethod
            def jvp(ctx, x_t, y_t):
                return x_t, y_t

        x = torch.rand(2, dtype=torch.double, requires_grad=True)
        y = torch.rand(2, dtype=torch.double, requires_grad=True)

        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=False,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=False,
        )

        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=False,
        )

        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=True,
        )

        x = torch.rand(2, dtype=torch.double, requires_grad=True)
        y = torch.rand(2, dtype=torch.double, requires_grad=False)
        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=True,
        )

    def test_gradcheck_forward_ad_respects_requires_grad(self):
        # Currently requires_grad is used as an easy way for gradcheck to know
        # which inputs of the function are meant to be differentiable
        jvp_count = [0]

        class UserFn(Function):
            @staticmethod
            def forward(ctx, x, y):
                return x.clone(), y.clone()

            @staticmethod
            def jvp(ctx, x_t, y_t):
                jvp_count[0] += 1
                return x_t, y_t

        # NB: In slow gradcheck we need to loop through numel times so use numel = 1 to ensure
        # that fast and slow have the same counts
        x = torch.rand(1, dtype=torch.double, requires_grad=True)
        y = torch.rand(1, dtype=torch.double, requires_grad=True)
        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=False,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=False,
        )
        self.assertEqual(jvp_count[0], 2)  # (2) once per input
        jvp_count = [0]

        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=False,
        )
        self.assertEqual(
            jvp_count[0], 6
        )  # (+4): (once with normal ZT (+1), once with efficient ZT (+1)) for each input (x2)
        jvp_count = [0]

        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=True,
        )
        self.assertEqual(
            jvp_count[0], 12
        )  # (+6): (compute batch of 2 with vmap (+1), with a loop (+2)) for each input (x2)
        jvp_count = [0]

        # Repeat the previous test except we mark one input with requires_grad=False
        # NB: _test_undefined_forward_mode is only (+1), when function has single differentiable input, not (+2)!
        # Otherwise, other counts are halved.
        x = torch.rand(1, dtype=torch.double, requires_grad=True)
        y = torch.rand(1, dtype=torch.double, requires_grad=False)
        gradcheck(
            UserFn.apply,
            (x, y),
            check_forward_ad=True,
            check_undefined_grad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_batched_forward_grad=True,
        )
        self.assertEqual(jvp_count[0], 5)  # 1 + 1 + 3
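
    # Illustrative sketch (not part of the original test suite): as the jvp counts above
    # rely on, gradcheck only differentiates with respect to inputs that have
    # requires_grad=True, so marking an input with requires_grad=False simply excludes
    # it from the check. The elementwise product below is a hypothetical example.
    def _sketch_requires_grad_selects_inputs(self):
        def fn(x, y):
            return x * y

        x = torch.rand(3, dtype=torch.double, requires_grad=True)
        y = torch.rand(3, dtype=torch.double, requires_grad=False)
        # Only d(fn)/dx is compared against finite differences here.
        self.assertTrue(gradcheck(fn, (x, y)))
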
Coastguard Worker gradcheck( 6472*da0073e9SAndroid Build Coastguard Worker UserFn.apply, 6473*da0073e9SAndroid Build Coastguard Worker (x, fwd_bad, bwd_bad), 6474*da0073e9SAndroid Build Coastguard Worker check_forward_ad=check_forward_ad, 6475*da0073e9SAndroid Build Coastguard Worker check_backward_ad=check_backward_ad, 6476*da0073e9SAndroid Build Coastguard Worker check_undefined_grad=check_backward_ad, 6477*da0073e9SAndroid Build Coastguard Worker check_batched_grad=check_backward_ad, 6478*da0073e9SAndroid Build Coastguard Worker fast_mode=fast_mode, 6479*da0073e9SAndroid Build Coastguard Worker ) 6480*da0073e9SAndroid Build Coastguard Worker 6481*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, dtype=torch.double, requires_grad=True) 6482*da0073e9SAndroid Build Coastguard Worker 6483*da0073e9SAndroid Build Coastguard Worker if not check_forward_ad and not check_backward_ad: 6484*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 6485*da0073e9SAndroid Build Coastguard Worker AssertionError, "Expected at least one of" 6486*da0073e9SAndroid Build Coastguard Worker ): 6487*da0073e9SAndroid Build Coastguard Worker run() 6488*da0073e9SAndroid Build Coastguard Worker continue 6489*da0073e9SAndroid Build Coastguard Worker 6490*da0073e9SAndroid Build Coastguard Worker if not fwd_should_fail and not bwd_should_fail: 6491*da0073e9SAndroid Build Coastguard Worker run() 6492*da0073e9SAndroid Build Coastguard Worker else: 6493*da0073e9SAndroid Build Coastguard Worker # If both fail, backward AD failure "hides" forward AD failure 6494*da0073e9SAndroid Build Coastguard Worker if fwd_should_fail: 6495*da0073e9SAndroid Build Coastguard Worker fail_msg = fwd_fail_err_msg 6496*da0073e9SAndroid Build Coastguard Worker if bwd_should_fail: 6497*da0073e9SAndroid Build Coastguard Worker fail_msg = bwd_fail_err_msg 6498*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, fail_msg): 6499*da0073e9SAndroid Build Coastguard Worker run() 6500*da0073e9SAndroid Build Coastguard Worker 6501*da0073e9SAndroid Build Coastguard Worker def test_gradcheck_forward_ad_batched_grad(self): 6502*da0073e9SAndroid Build Coastguard Worker x = torch.rand(2, dtype=torch.double, requires_grad=True) 6503*da0073e9SAndroid Build Coastguard Worker 6504*da0073e9SAndroid Build Coastguard Worker # multiple inputs and outputs with non-tensors inputs 6505*da0073e9SAndroid Build Coastguard Worker def fn1(a: torch.Tensor, b: int): 6506*da0073e9SAndroid Build Coastguard Worker return a.clone(), a + 1 6507*da0073e9SAndroid Build Coastguard Worker 6508*da0073e9SAndroid Build Coastguard Worker gradcheck( 6509*da0073e9SAndroid Build Coastguard Worker fn1, 6510*da0073e9SAndroid Build Coastguard Worker (x, 1), 6511*da0073e9SAndroid Build Coastguard Worker check_forward_ad=True, 6512*da0073e9SAndroid Build Coastguard Worker check_backward_ad=False, 6513*da0073e9SAndroid Build Coastguard Worker check_batched_grad=False, 6514*da0073e9SAndroid Build Coastguard Worker check_undefined_grad=False, 6515*da0073e9SAndroid Build Coastguard Worker check_batched_forward_grad=True, 6516*da0073e9SAndroid Build Coastguard Worker ) 6517*da0073e9SAndroid Build Coastguard Worker 6518*da0073e9SAndroid Build Coastguard Worker # unrelated inputs: tangent for c is None 6519*da0073e9SAndroid Build Coastguard Worker def fn2(a: torch.Tensor, c: torch.Tensor): 6520*da0073e9SAndroid Build Coastguard Worker return a.clone() 6521*da0073e9SAndroid Build Coastguard Worker 6522*da0073e9SAndroid Build Coastguard Worker 

    def test_gradcheck_forward_ad_batched_grad(self):
        x = torch.rand(2, dtype=torch.double, requires_grad=True)

        # multiple inputs and outputs with non-tensor inputs
        def fn1(a: torch.Tensor, b: int):
            return a.clone(), a + 1

        gradcheck(
            fn1,
            (x, 1),
            check_forward_ad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_undefined_grad=False,
            check_batched_forward_grad=True,
        )

        # unrelated inputs: tangent for c is None
        def fn2(a: torch.Tensor, c: torch.Tensor):
            return a.clone()

        gradcheck(
            fn2,
            (x, x.clone()),
            check_forward_ad=True,
            check_backward_ad=False,
            check_batched_grad=False,
            check_undefined_grad=False,
            check_batched_forward_grad=True,
        )

        class Fn(Function):
            @staticmethod
            def forward(ctx, foo):
                return foo * 2

            @staticmethod
            def vjp(ctx, gO):
                return gO * 2

            @staticmethod
            def jvp(ctx, gI):
                torch.randn_like(gI)
                return gI * 2

        msg = "vmap: We do not yet support calling random operations inside of vmap"
        with self.assertRaisesRegex(RuntimeError, msg):
            gradcheck(
                Fn.apply, (x,), check_forward_ad=True, check_batched_forward_grad=True
            )

    def test_version_counter(self):
        x = torch.randn(1, 2)

        # In-place op bumps version
        x_saved_version = x._version
        x.add_(1).add_(1)
        self.assertTrue(x._version > x_saved_version)

        # Differentiable view shares version counter
        xz = x[:]
        self.assertTrue(x._version == xz._version)
        xz.add_(1)
        self.assertTrue(x._version == xz._version)

        # `x.data = y` preserves version counter of `x`
        x_saved_version = x._version
        x.data = torch.randn(2, 3)
        self.assertTrue(x._version == x_saved_version)
        x.add_(1)
        self.assertTrue(x._version > x_saved_version)
        # Make sure `x` is still using the same version counter it shares with `xz`
        self.assertTrue(x._version == xz._version)

        # In-place op on `xz` also updates version of `x`,
        # because they share the version counter
        xz.add_(1)
        self.assertTrue(x._version == xz._version)
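
    # Exposition-only sketch (not part of the original test suite): the version
    # counter checked above is what lets autograd detect that a tensor saved
    # for backward was later modified in place. A minimal, hedged example,
    # relying only on names already imported in this module.
    def _sketch_version_counter_guards_saved_tensors(self):
        a = torch.randn(3, requires_grad=True)
        b = a * 1  # non-leaf copy of `a`, so in-place ops on it are allowed
        out = b.sin().sum()  # sin saves `b` for its backward
        b.add_(1)  # bumps b._version past the version recorded at save time
        with self.assertRaisesRegex(
            RuntimeError, "modified by an inplace operation"
        ):
            out.backward()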

    def test_set_data_tensorimpl_type(self):
        # Dense tensor has impl of type `TensorImpl`, while sparse tensor has impl
        # of type `SparseTensorImpl`.
        x = torch.randn(1, 2)
        x_s = torch.sparse_coo_tensor(torch.zeros([1, 1]), torch.ones([1]))
        with self.assertRaisesRegex(RuntimeError, "incompatible tensor type"):
            x.data = x_s

    def test_set_data_preserve_pyobj(self):
        a = torch.randn(1, 2)
        b = torch.randn(1, 2)
        b_id_saved = id(b)
        b.data = a
        self.assertTrue(b_id_saved == id(b))

    def test_set_data_self_requires_grad(self):
        a = torch.tensor(1.0, requires_grad=True)
        b = torch.tensor(2.0)
        c = torch.tensor(3, dtype=torch.int64)
        a.data = b
        with self.assertRaisesRegex(
            RuntimeError, "must be floating point or complex dtype"
        ):
            a.data = c

    @unittest.skipIf(IS_WINDOWS, "Skipping because it doesn't work on Windows")
    def test_thread_shutdown(self):
        code = """import torch
from torch.autograd import Function
class MyFunction(Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad):
        return grad

# Run on cuda if it is available to ensure that the worker thread
# is properly initialized by the time we exit.
device = "cuda" if torch.cuda.is_available() else "cpu"

for shape in [(1,), ()]:
    v = torch.ones(shape, requires_grad=True, device=device)
    MyFunction.apply(v).backward()
"""
        s = TestCase.runWithPytorchAPIUsageStderr(code)
        # The autograd engine creates worker threads only when GPU devices are present.
        # So make sure that we do shut down threads when we're testing CUDA, and that
        # there is no thread to shut down when we're not using CUDA.
        if TEST_CUDA or torch.backends.mps.is_available() or torch.xpu.is_available():
            self.assertRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")
        else:
            self.assertNotRegex(s, "PYTORCH_API_USAGE torch.autograd.thread_shutdown")

    @unittest.skipIf(
        IS_MACOS,
        "Fails with SIGBUS on macOS; https://github.com/pytorch/pytorch/issues/25941",
    )
    def test_deep_reentrant(self):
        class DeepReentrant(Function):
            @staticmethod
            def forward(ctx, x):
                with torch.enable_grad():
                    ctx.x = Variable(x.detach(), requires_grad=True)
                    ctx.x = ctx.x - 1
                return ctx.x.detach()

            @staticmethod
            def backward(ctx, x):
                if ctx.x < 0:
                    return x
                with torch.enable_grad():
                    DeepReentrant.apply(ctx.x).sum().backward()
                return x

        # Test stack overflow escape mechanism
        v = torch.tensor(2000.0, requires_grad=True)
        # This will cause stack overflow if reentrant calls are handled
        # in the same thread recursively
        DeepReentrant.apply(v).sum().backward()

        # Test stack overflow escape mechanism multiple times
        # to ensure reusing workers in the pool works fine
        v2 = torch.tensor(200.0, requires_grad=True)
        DeepReentrant.apply(v2).sum().backward()
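
    # Exposition-only sketch (not part of the original test suite): the
    # "reentrant" calls exercised above are simply backward() calls issued
    # from inside another backward(). A minimal, hedged example with a single
    # level of reentrancy, using only names already imported in this module.
    def _sketch_single_level_reentrant_backward(self):
        class OneReentrant(Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, grad):
                # Run a small, independent backward pass from inside this one.
                with torch.enable_grad():
                    t = torch.tensor(1.0, requires_grad=True)
                    (t * 2).backward()
                return grad

        v = torch.tensor(1.0, requires_grad=True)
        OneReentrant.apply(v).backward()
        self.assertEqual(v.grad, torch.tensor(1.0))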

    def test_reentrant_priority(self):
        order = []

        class MyFunction(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, x):
                order.append("MyFunction")
                return x

        class Reentrant(Function):
            @staticmethod
            def forward(ctx, x):
                with torch.enable_grad():
                    ctx.x = Variable(x.detach(), requires_grad=True)
                    ctx.x = ctx.x - 1
                return ctx.x.detach()

            @staticmethod
            def backward(ctx, x):
                order.append("Reentrant")
                if ctx.x < 0:
                    return x
                with torch.enable_grad():
                    Reentrant.apply(ctx.x).backward()
                return x

        a = MyFunction.apply(torch.tensor(6.0, requires_grad=True))
        b = Reentrant.apply(torch.tensor(9.0, requires_grad=True))
        v = a * b
        v.backward()
        # The tasks for the Reentrant and MyFunction backward() will be added
        # to the queue in the autograd engine at the same time. The backward
        # for Reentrant will be executed first, which will then add other
        # backward tasks to the queue. We want to ensure all the reentrant tasks
        # are prioritized over the MyFunction backward task regardless of their
        # sequence numbers.
        self.assertEqual(len(order), 11)
        self.assertEqual(order.count("Reentrant"), 10)
        self.assertEqual(order[-1], "MyFunction")

    @slowTest
    def test_checkpointing(self):
        num_inp = 2000
        nz_inp = 10
        nz_out = 10
        nz_bottleneck = 1000

        # small proxy network for some complex reasoning we want to do per input
        module = nn.Sequential(
            nn.Linear(nz_inp, nz_bottleneck),
            nn.ReLU(),
            nn.Linear(nz_bottleneck, nz_inp),
        )

        feat_combined = []
        for r in range(num_inp):
            data_r = torch.empty(1, nz_inp)
            data_r.uniform_()
            data_r.requires_grad = True
            feat_r = checkpoint(module, data_r, use_reentrant=True)
            feat_combined.append(feat_r)

        # compute mean as a proxy for some joint reasoning
        mean_combined = torch.stack(feat_combined).mean()
        mean_combined.backward()
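
    # Exposition-only sketch (not part of the original test suite): what
    # checkpointing buys above is that activations are not kept alive after
    # the forward; the wrapped function is re-run during backward instead. A
    # minimal, hedged example that makes the extra forward call visible.
    def _sketch_checkpoint_recomputes_forward(self):
        calls = [0]

        def fwd(x):
            calls[0] += 1
            return torch.exp(x)

        inp = torch.randn(3, requires_grad=True)
        out = checkpoint(fwd, inp, use_reentrant=False)
        self.assertEqual(calls[0], 1)  # forward ran once
        out.sum().backward()
        self.assertEqual(calls[0], 2)  # backward re-ran the forward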

    def _test_checkpointing_non_reentrant_autocast(self, device_type):
        for enabled in [True, False]:

            def foo(x, y, z):
                # torch.mm is on autocast's list of ops that should run in
                # the autocast precision
                x = torch.mm(x, y)
                y = torch.mm(x, z)
                z = torch.mm(z, z)
                expected_dtype = torch.float32 if not enabled else torch.bfloat16
                self.assertEqual(expected_dtype, z.dtype)
                return z

            x = torch.randn(3, 3, requires_grad=True)
            y = torch.randn(3, 3, requires_grad=True)
            z = torch.randn(3, 3, requires_grad=True)
            if device_type == "cuda":
                x = x.cuda()
                y = y.cuda()
                z = z.cuda()

            with torch.autocast(
                enabled=enabled, device_type=device_type, dtype=torch.bfloat16
            ):
                loss = checkpoint(foo, x, y, z, use_reentrant=False)
                loss = loss.sum()

            # Without saving + recasting the autocast type, this would raise an error
            # in autograd about mismatched dtypes.
            loss.backward()  # triggers recomputation to check it runs in bfloat16

    def test_checkpointing_non_reentrant_autocast_cpu(self):
        """
        Test that autocast args such as the dtype are preserved during non-reentrant
        checkpoint recomputation on CPU.
        """
        self._test_checkpointing_non_reentrant_autocast(device_type="cpu")

    @unittest.skipIf(
        not torch.cuda.is_available() or not torch.cuda.is_bf16_supported(),
        "Test requires CUDA bf16 support",
    )
    def test_checkpointing_non_reentrant_autocast_gpu(self):
        """
        Test that autocast args/kwargs such as the dtype are preserved during
        non-reentrant checkpoint recomputation on GPU.
        """
        self._test_checkpointing_non_reentrant_autocast(device_type="cuda")

    @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
    @slowTest
    def test_checkpointing_without_reentrant_memory_savings(self):
        class MyModel(nn.Module):
            def __init__(self, n, use_checkpoint, use_reentrant):
                super().__init__()
                self.n = n
                self.use_checkpoint = use_checkpoint
                self.use_reentrant = use_reentrant
                self.layers = nn.ModuleList()
                for i in range(self.n):
                    layer = nn.Sequential(
                        nn.Linear(256, 256), nn.Linear(256, 256), nn.Linear(256, 256)
                    )
                    self.layers.append(layer)
                # pre-allocate the grad so that increased memory usage is mainly
                # due to activations.
                for layer in self.layers:
                    for lin in layer:
                        lin.weight.grad = torch.ones_like(lin.weight)
                        lin.bias.grad = torch.ones_like(lin.bias)

            def forward(self, x):
                for i in range(self.n):
                    if not self.use_checkpoint:
                        x = self.layers[i](x)
                    else:
                        x = checkpoint(
                            self.layers[i], x, use_reentrant=self.use_reentrant
                        )

                return x

        model_no_checkpoint = MyModel(
            8, use_checkpoint=False, use_reentrant=False
        ).cuda()
        model_reentrant_checkpoint = MyModel(
            8, use_checkpoint=True, use_reentrant=True
        ).cuda()
        model_no_reentrant_checkpoint = MyModel(
            8, use_checkpoint=True, use_reentrant=False
        ).cuda()

        x = torch.randn(100, 256, requires_grad=True, device="cuda")

        torch.cuda.reset_peak_memory_stats()
        loss = model_no_checkpoint(x.clone()).sum()
        loss.backward()
        mem_no_checkpoint = torch.cuda.max_memory_allocated()

        torch.cuda.reset_peak_memory_stats()
        loss = model_reentrant_checkpoint(x.clone()).sum()
        loss.backward()
        mem_reentrant_checkpoint = torch.cuda.max_memory_allocated()

        torch.cuda.reset_peak_memory_stats()
        loss = model_no_reentrant_checkpoint(x.clone()).sum()
        loss.backward()
        mem_no_reentrant_checkpoint = torch.cuda.max_memory_allocated()

        self.assertTrue(mem_reentrant_checkpoint < mem_no_checkpoint)
        self.assertTrue(mem_no_reentrant_checkpoint < mem_no_checkpoint)

    def test_checkpointing_without_reentrant_custom_function_works(self):
        msg = "Unpack is being triggered for a tensor that was already unpacked once"

        class MyFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y, z):
                w = x * y * z
                out = w + w
                ctx.save_for_backward(x, y, z, w, out)
                return out

            @staticmethod
            def backward(ctx, grad_out):
                x, y, z, w, out = ctx.saved_tensors
                # Accessing the saved Tensors a second time will raise because
                # recomputed tensors get cleared as soon as they are unpacked.
                # A recomputation is only triggered if your backward has a new
                # graph-task id.
                with self.assertRaisesRegex(RuntimeError, msg):
                    x_2, y_2, z_2, w_2, out_2 = ctx.saved_tensors
                return x, y, z

        x = torch.tensor(1.0, requires_grad=True)
        y = torch.tensor(2.0, requires_grad=True)
        z = torch.tensor(3.0, requires_grad=True)

        def foo(x, y, z):
            x = x * y * z
            y = y * y * z
            z = z * z
            out = MyFunc.apply(x, y, z)
            return out

        out = checkpoint(foo, x, y, z, use_reentrant=False)
        out.sum().backward()
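
    # Exposition-only sketch (not part of the original test suite): the
    # pack/unpack behavior relied on above is the generic saved-tensor hooks
    # mechanism that non-reentrant checkpoint builds on. A minimal, hedged
    # example that stores indices instead of tensors.
    def _sketch_saved_tensors_hooks(self):
        storage = []

        def pack(t):
            storage.append(t)
            return len(storage) - 1  # hand autograd an index, not the tensor

        def unpack(idx):
            return storage[idx]  # called when backward needs the tensor back

        a = torch.randn(3, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            out = a.exp().sum()
        out.backward()  # unpack runs here to retrieve exp's saved output
        self.assertEqual(a.grad, a.exp())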

    def test_checkpointing_without_reentrant_with_context_fn(self):
        class VerboseTorchDispatchMode(TorchDispatchMode):
            def __init__(self) -> None:
                self.operators = []

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                self.operators.append(func.__name__)
                return func(*args, **kwargs)

        x = torch.tensor(1.0, requires_grad=True)
        verbose_mode = VerboseTorchDispatchMode()

        def context_fn():
            return verbose_mode, contextlib.nullcontext()

        out = checkpoint(
            lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
        )
        self.assertEqual(verbose_mode.operators, ["exp.default"])

        verbose_mode.operators = []

        def context_fn():
            return contextlib.nullcontext(), verbose_mode

        out = checkpoint(
            lambda x: x.exp(), x, use_reentrant=False, context_fn=context_fn
        )
        out.backward()
        self.assertEqual(
            verbose_mode.operators, ["exp.default", "detach.default", "detach.default"]
        )

        with self.assertRaisesRegex(
            Exception, "only supported when use_reentrant=False"
        ):
            out = checkpoint(
                lambda x: x.sin(), x, use_reentrant=True, context_fn=context_fn
            )

    def test_checkpoint_warns_if_use_reentrant_not_passed_explcitly(self):
        a = torch.randn(1, requires_grad=True)

        # Passing explicitly should not warn
        self.assertNotWarn(lambda: checkpoint(lambda x: x, a, use_reentrant=False))

        # Not passing explicitly warns
        with self.assertWarnsOnceRegex(
            UserWarning, ".*the use_reentrant parameter should be passed explicitly.*"
        ):
            checkpoint(lambda x: x, a)

    def test_checkpoint_sequential_warns_if_use_reentrant_not_passed_explcitly(self):
        a = torch.randn(3, requires_grad=True)
        modules_list = [
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
            torch.nn.Linear(3, 3),
        ]

        # Passing explicitly should not warn
        self.assertNotWarn(
            lambda: checkpoint_sequential(modules_list, 3, a, use_reentrant=False)
        )

        # Not passing explicitly warns
        with self.assertWarnsOnceRegex(
            UserWarning, ".*the use_reentrant parameter should be passed explicitly.*"
        ):
            checkpoint_sequential(modules_list, 3, a)

    def test_checkpoint_detects_non_determinism(self):
        def save_3_tensors(x):
            out = x.sin().exp()
            out = out.sin()
            return out

        def save_2_tensors(x):
            return x.sin().exp()

        def save_2_tensors_alt(x):
            return x.sin() * torch.tensor([1.0, 2.0])

        def get_non_det_fn(orig_fn, recompute_fn):
            counter = [0]

            def fn(x):
                if counter[0] == 0:
                    counter[0] += 1
                    return orig_fn(x)
                else:
                    return recompute_fn(x)

            return fn

        a = torch.randn(1, requires_grad=True)

        # Save fewer tensors during recompute
        fn = get_non_det_fn(orig_fn=save_3_tensors, recompute_fn=save_2_tensors)
        with self.assertRaisesRegex(
            RuntimeError, "A different number of tensors was saved"
        ):
            out = checkpoint(fn, a, use_reentrant=False)
            out.backward()

        # Save more tensors during recompute
        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            with self.assertRaisesRegex(
                RuntimeError, "trying to save more tensors during recomputation"
            ):
                out = checkpoint(fn, a, use_reentrant=False)
                out.backward()

        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_3_tensors)
        # If early stopping is enabled, we would not raise (the results would be correct anyway)
        out = checkpoint(fn, a, use_reentrant=False)
        out.backward()

        # Save the same number of tensors but the shape is different
        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)
        with self.assertRaisesRegex(RuntimeError, "tensors have different metadata"):
            out = checkpoint(fn, a, use_reentrant=False)
            out.backward()

        # Get the debug message if debug=True
        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)

        with self.assertRaisesRegex(
            RuntimeError,
            "You are seeing this error because you passed `debug=True` to checkpoint",
        ):
            out = checkpoint(fn, a, use_reentrant=False, debug=True)
            out.backward()

        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)

        with self.assertRaisesRegex(
            RuntimeError,
            "You are seeing this error because you passed `debug=True` to checkpoint",
        ):
            with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):
                out = checkpoint(fn, a, use_reentrant=False, debug=False)
                out.backward()

        fn = get_non_det_fn(orig_fn=save_2_tensors, recompute_fn=save_2_tensors_alt)

        with self.assertRaisesRegex(
            RuntimeError, "Recomputed values for the following tensors have different"
        ):
            with torch.utils.checkpoint.set_checkpoint_debug_enabled(False):
                out = checkpoint(fn, a, use_reentrant=False, debug=True)
                out.backward()
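
    # Exposition-only sketch (not part of the original test suite): the
    # determinism checks above exist because the recompute must reproduce the
    # original forward. For RNG-dependent ops, checkpoint stashes and restores
    # the RNG state by default (preserve_rng_state=True), so dropout draws the
    # same mask twice. A minimal, hedged example; early stopping is disabled
    # so the recompute runs the whole function.
    def _sketch_checkpoint_preserves_rng_state(self):
        seen = []

        def fwd(x):
            out = torch.nn.functional.dropout(x, p=0.5, training=True)
            seen.append(out.detach().clone())
            return out

        x = torch.randn(8, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fwd, x, use_reentrant=False)
            out.sum().backward()
        # One entry from the original forward, one from the recompute; they match.
        self.assertEqual(len(seen), 2)
        self.assertEqual(seen[0], seen[1])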

    def test_access_saved_tensor_twice_without_recomputation_works(self):
        count = [0]

        def foo(a):
            count[0] += 1
            b = a * a
            c = a * b
            d = torch.exp(a)
            return d

        a = torch.randn(5, requires_grad=True)
        d = checkpoint(foo, a, use_reentrant=False)
        self.assertEqual(count[0], 1)
        # Recomputed variables only persist within a particular backward call.
        # If _saved_result is accessed outside of a backward, it will trigger
        # a recompute. And afterwards, those recomputed results are immediately
        # cleared.
        d.grad_fn._saved_result
        self.assertEqual(count[0], 2)
        # Second access will trigger another recompute
        d.grad_fn._saved_result
        self.assertEqual(count[0], 3)
        # Backward clears the saved variable
        d.sum().backward()
        self.assertEqual(count[0], 4)
        # Now it raises an error
        with self.assertRaisesRegex(
            RuntimeError,
            "or directly access saved tensors after they have already been freed",
        ):
            d.grad_fn._saved_result

    @slowTest
    @parametrize("input_requires_grad", [True, False])
    def test_checkpointing_without_reentrant(self, input_requires_grad):
        """
        Basic test for checkpoint without reentrant autograd.
        """
        num_inp = 2000
        nz_inp = 10
        nz_out = 10
        nz_bottleneck = 1000

        # small proxy network for some complex reasoning we want to do per input
        module = nn.Sequential(
            nn.Linear(nz_inp, nz_bottleneck),
            nn.ReLU(),
            nn.Linear(nz_bottleneck, nz_inp),
        )

        # Module holder for testing that activation checkpointing without
        # reentrant autograd supports kwargs.
        class MyModule(nn.Module):
            def __init__(self, mod):
                super().__init__()
                self.module = mod

            def forward(self, data):
                return self.module(data)

        module = MyModule(mod=module)

        # Run model with and without checkpointing and verify gradients are
        # equivalent, regardless of whether inputs require grads or not.
        module_copy = deepcopy(module)

        feat_combined = []
        feat_combined_no_checkpoint = []
        for r in range(num_inp):
            data_r = torch.empty(1, nz_inp)
            data_r.uniform_()
            data_r.requires_grad = input_requires_grad
            data_r_copy = data_r.clone()
            feat_r = checkpoint(module, data=data_r, use_reentrant=False)
            feat_combined.append(feat_r)
            feat_r_no_checkpoint = module_copy(data_r)
            feat_combined_no_checkpoint.append(feat_r_no_checkpoint)

        # compute mean as a proxy for some joint reasoning
        mean_combined = torch.stack(feat_combined).mean()
        mean_combined.backward()
        mean_combined_no_checkpoint = torch.stack(feat_combined_no_checkpoint).mean()
        mean_combined_no_checkpoint.backward()

        for checkpoint_param, param in zip(
            module.parameters(), module_copy.parameters()
        ):
            self.assertEqual(checkpoint_param.grad, param.grad)

    def test_checkpoint_valid_reset_on_error(self):
        a = torch.randn(2, 2, requires_grad=True)

        with self.assertRaisesRegex(
            Exception, "torch.utils.checkpoint is incompatible"
        ):
            b = checkpoint(torch.exp, a, use_reentrant=True).sum()
            torch.autograd.grad(b, (a,))

        c = checkpoint(torch.exp, a, use_reentrant=True).sum()
        c.backward()

    @parametrize("use_reentrant", [True, False])
    def test_checkpointing_without_reentrant_detached_tensor(self, use_reentrant):
        class NoGradModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(2, 2, bias=False)
                self.lin2 = nn.Linear(2, 2, bias=False)

            def forward(self, x):
                with torch.no_grad():
                    return self.lin2(self.linear(x))

        module = NoGradModule()

        err_ctx = (
            self.assertRaisesRegex(
                RuntimeError, "none of output has requires_grad=True"
            )
            if use_reentrant
            else contextlib.nullcontext()
        )

        a = torch.randn(2, 2, requires_grad=True)
        for _ in range(3):
            with err_ctx:
                # out does not require grad
                out = checkpoint(module, a, use_reentrant=use_reentrant)
                # Make loss require grad, otherwise we would run into
                # "element 0 of tensors does not require grad and does not have a grad_fn"
                out += a
                out.sum().backward()

    def test_checkpointing_without_reentrant_saved_object_identity(self):
        x_backward = None

        class Test(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                ctx.save_for_backward(y)
                return x

            @staticmethod
            def backward(ctx, x):
                nonlocal x_backward
                (x_backward,) = ctx.saved_tensors
                return x, None

        a = torch.tensor(1.0, requires_grad=True)
        b = torch.tensor(1.0, requires_grad=False)

        Test.apply(a, b).backward()
        self.assertIs(b, x_backward)

        x_backward = None
        checkpoint(Test.apply, a, b, use_reentrant=False).backward()
        self.assertIs(b, x_backward)

    def test_checkpointing_without_reentrant_correct_grad(self):
        """
        Verifies that correct gradients are calculated for checkpoint
        without reentrant autograd, for both backward() and autograd.grad().
        """
        a = torch.randn(2, 2, requires_grad=True)

        b = torch.exp(a).sum()
        b.backward()
        b_grad = a.grad

        a.grad = None
        c = checkpoint(torch.exp, a, use_reentrant=False).sum()
        c.backward()
        c_grad = a.grad

        a.grad = None
        d = checkpoint(torch.exp, a, use_reentrant=False).sum()
        (d_grad,) = torch.autograd.grad(d, (a,))

        self.assertEqual(b_grad, c_grad)
        self.assertEqual(b_grad, d_grad)

    # PYTORCH_TEST_WITH_DYNAMO=1 test fails on CI but can't repro locally
    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/127115")
    def test_checkpointing_without_reentrant_dataparallel(self):
        """
        Verifies gradient correctness when checkpoint without reentrant autograd
        is used in conjunction with DataParallel.
        """

        class LinearModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = nn.Linear(2, 2, bias=False)

            def forward(self, inp):
                return self.linear(inp)

        a = torch.randn(2, 2, requires_grad=True)
        if torch.cuda.is_available():
            a = a.cuda()

        model = LinearModule()
        if torch.cuda.is_available():
            model = model.cuda()

        b = deepcopy(model)(a).sum()
        b.backward()
        b_grad = a.grad

        a.grad = None

        module = torch.nn.DataParallel(deepcopy(model))
        c = checkpoint(module, a, use_reentrant=False).sum()
        c.backward()
        c_grad = a.grad

        self.assertEqual(b_grad, c_grad)

    def test_checkpointing_without_reentrant_parameter_used_in_an_out(self):
        """
        Ensures that gradient hooks are only called once per tensor.
        """
        w = torch.randn(10, 10, requires_grad=True)
        count = 0

        def hook(grad):
            nonlocal count
            count += 1

        w.register_hook(hook)
        x = torch.rand(10, 10, requires_grad=True)
        h = w * x  # Using w outside the checkpoint
        out = checkpoint(
            lambda x: w * x, h, use_reentrant=False
        )  # Using w inside the checkpoint

        out.sum().backward()
        # should only call hook once
        self.assertEqual(count, 1)

    # https://github.com/pytorch/pytorch/issues/127115
    @xfailIfTorchDynamo
    def test_checkpointing_without_reentrant_arbitrary_input_output(self):
        """
        Ensures checkpointing without reentrant autograd works with functions
        with arbitrary input/output structures.
        """

        class MyModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layer = torch.nn.Linear(5, 5, bias=False)

            def forward(self, dict_input):
                tensor = dict_input["tensor"]
                return {"result": self.layer(tensor)}

        model_no_checkpoint = MyModel()
        model_checkpoint_without_reentrant = deepcopy(model_no_checkpoint)

        inp = {"tensor": torch.randn(5, 5)}

        out_no_checkpoint = model_no_checkpoint(inp)["result"].sum()

        out_checkpoint = checkpoint(
            model_checkpoint_without_reentrant, inp, use_reentrant=False
        )["result"].sum()

        self.assertEqual(out_checkpoint, out_no_checkpoint)

        out_no_checkpoint.backward()
        out_checkpoint.backward()

        for param, checkpoint_param in zip(
            model_no_checkpoint.parameters(),
            model_checkpoint_without_reentrant.parameters(),
        ):
            self.assertEqual(param.grad, checkpoint_param.grad)

    def test_callback_adds_callback(self):
        called = [0]

        def callback_final():
            called[0] += 1

        def callback_adds_callback():
            called[0] += 1
            Variable._execution_engine.queue_callback(callback_final)

        class MyFunc(Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            @once_differentiable
            def backward(ctx, grad):
                Variable._execution_engine.queue_callback(callback_adds_callback)
                return grad

        a = torch.rand((3, 3), requires_grad=True)
        b = MyFunc.apply(a)
        b.sum().backward()

        self.assertEqual(called[0], 2)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_callback_propagates_errors_from_device_thread(self):
        def callback():
            raise RuntimeError("blah")

        def hook_with_callback(*args):
            torch.autograd.Variable._execution_engine.queue_callback(callback)

        t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda"))
        t.register_hook(hook_with_callback)
        output = t**2
        loss = output.sum()

        with self.assertRaisesRegex(RuntimeError, "blah"):
            loss.backward()
7384*da0073e9SAndroid Build Coastguard Worker Variable._execution_engine.queue_callback(inc_inner_counter) 7385*da0073e9SAndroid Build Coastguard Worker 7386*da0073e9SAndroid Build Coastguard Worker return input 7387*da0073e9SAndroid Build Coastguard Worker 7388*da0073e9SAndroid Build Coastguard Worker class MyReentrantFunc(Function): 7389*da0073e9SAndroid Build Coastguard Worker @staticmethod 7390*da0073e9SAndroid Build Coastguard Worker def forward(ctx, input): 7391*da0073e9SAndroid Build Coastguard Worker return input 7392*da0073e9SAndroid Build Coastguard Worker 7393*da0073e9SAndroid Build Coastguard Worker @staticmethod 7394*da0073e9SAndroid Build Coastguard Worker @once_differentiable 7395*da0073e9SAndroid Build Coastguard Worker def backward(ctx, input): 7396*da0073e9SAndroid Build Coastguard Worker if 0 in install_callbacks_in_depths: 7397*da0073e9SAndroid Build Coastguard Worker # Add a callback to execute. 7398*da0073e9SAndroid Build Coastguard Worker Variable._execution_engine.queue_callback(inc_outer_counter) 7399*da0073e9SAndroid Build Coastguard Worker # Reentrant backward call. 7400*da0073e9SAndroid Build Coastguard Worker tmp_inp = input.detach().requires_grad_() 7401*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 7402*da0073e9SAndroid Build Coastguard Worker tmp_out = (MyFunc.apply(tmp_inp)).sum() 7403*da0073e9SAndroid Build Coastguard Worker tmp_out.backward() 7404*da0073e9SAndroid Build Coastguard Worker return input 7405*da0073e9SAndroid Build Coastguard Worker 7406*da0073e9SAndroid Build Coastguard Worker t1 = torch.rand((3, 3), requires_grad=True) 7407*da0073e9SAndroid Build Coastguard Worker t2 = MyReentrantFunc.apply(t1) 7408*da0073e9SAndroid Build Coastguard Worker t3 = t2.sum() 7409*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([t3]) 7410*da0073e9SAndroid Build Coastguard Worker 7411*da0073e9SAndroid Build Coastguard Worker return counter 7412*da0073e9SAndroid Build Coastguard Worker 7413*da0073e9SAndroid Build Coastguard Worker def test_reentrant_with_callbacks_depth_0(self): 7414*da0073e9SAndroid Build Coastguard Worker # Verify callback is called only once. 7415*da0073e9SAndroid Build Coastguard Worker ret = self._test_reentrant_with_callbacks([0]) 7416*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, ret["outer"]) 7417*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, ret["inner"]) 7418*da0073e9SAndroid Build Coastguard Worker 7419*da0073e9SAndroid Build Coastguard Worker def test_reentrant_with_callbacks_depth_1(self): 7420*da0073e9SAndroid Build Coastguard Worker # Verify callback is called only once. 7421*da0073e9SAndroid Build Coastguard Worker ret = self._test_reentrant_with_callbacks([1]) 7422*da0073e9SAndroid Build Coastguard Worker self.assertEqual(0, ret["outer"]) 7423*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, ret["inner"]) 7424*da0073e9SAndroid Build Coastguard Worker 7425*da0073e9SAndroid Build Coastguard Worker def test_reentrant_with_callbacks_both_depths(self): 7426*da0073e9SAndroid Build Coastguard Worker # Verify callback is called twice. 
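# Note: callbacks queued with Variable._execution_engine.queue_callback are
# executed when the backward pass finishes. With callbacks installed at both
# depths, the outer (reentrant) backward queues one callback and the nested
# backward queues another, so both counters are expected to end up at 1.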
7427*da0073e9SAndroid Build Coastguard Worker ret = self._test_reentrant_with_callbacks([0, 1]) 7428*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, ret["outer"]) 7429*da0073e9SAndroid Build Coastguard Worker self.assertEqual(1, ret["inner"]) 7430*da0073e9SAndroid Build Coastguard Worker 7431*da0073e9SAndroid Build Coastguard Worker def test_reentrant_with_leaf_variable_hook(self): 7432*da0073e9SAndroid Build Coastguard Worker handle = None 7433*da0073e9SAndroid Build Coastguard Worker param = torch.rand(10, requires_grad=True) 7434*da0073e9SAndroid Build Coastguard Worker 7435*da0073e9SAndroid Build Coastguard Worker def add_gradient_penalty_to_grad(grad): 7436*da0073e9SAndroid Build Coastguard Worker handle.remove() 7437*da0073e9SAndroid Build Coastguard Worker old_param_grad = grad 7438*da0073e9SAndroid Build Coastguard Worker param.grad = None 7439*da0073e9SAndroid Build Coastguard Worker # Add some sort of gradient penalty by directly updating the gradients 7440*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 7441*da0073e9SAndroid Build Coastguard Worker g = grad.detach().requires_grad_() 7442*da0073e9SAndroid Build Coastguard Worker new_param = param.detach().requires_grad_() 7443*da0073e9SAndroid Build Coastguard Worker out = ((g * 2) + new_param).sum() 7444*da0073e9SAndroid Build Coastguard Worker out.backward() 7445*da0073e9SAndroid Build Coastguard Worker res = g.grad + grad 7446*da0073e9SAndroid Build Coastguard Worker param.grad = old_param_grad 7447*da0073e9SAndroid Build Coastguard Worker return res 7448*da0073e9SAndroid Build Coastguard Worker 7449*da0073e9SAndroid Build Coastguard Worker handle = param.register_hook(add_gradient_penalty_to_grad) 7450*da0073e9SAndroid Build Coastguard Worker # Forward pass 7451*da0073e9SAndroid Build Coastguard Worker tmp = param * param 7452*da0073e9SAndroid Build Coastguard Worker loss = tmp.sum() 7453*da0073e9SAndroid Build Coastguard Worker # Compute the gradients 7454*da0073e9SAndroid Build Coastguard Worker loss.backward() 7455*da0073e9SAndroid Build Coastguard Worker 7456*da0073e9SAndroid Build Coastguard Worker def test_reentrant_with_non_leaf_variable_hook(self): 7457*da0073e9SAndroid Build Coastguard Worker handle = None 7458*da0073e9SAndroid Build Coastguard Worker param = torch.rand(10, requires_grad=True) 7459*da0073e9SAndroid Build Coastguard Worker 7460*da0073e9SAndroid Build Coastguard Worker def manual_increase_gradient(grad): 7461*da0073e9SAndroid Build Coastguard Worker handle.remove() 7462*da0073e9SAndroid Build Coastguard Worker # Add some sort of gradient penalty by directly updating the gradients 7463*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 7464*da0073e9SAndroid Build Coastguard Worker g = grad.detach().requires_grad_() 7465*da0073e9SAndroid Build Coastguard Worker out = ((g * 2) + 5).sum() 7466*da0073e9SAndroid Build Coastguard Worker out.backward() 7467*da0073e9SAndroid Build Coastguard Worker res = g.grad + grad 7468*da0073e9SAndroid Build Coastguard Worker return res 7469*da0073e9SAndroid Build Coastguard Worker 7470*da0073e9SAndroid Build Coastguard Worker # Forward pass 7471*da0073e9SAndroid Build Coastguard Worker tmp = param * param 7472*da0073e9SAndroid Build Coastguard Worker handle = tmp.register_hook(manual_increase_gradient) 7473*da0073e9SAndroid Build Coastguard Worker loss = tmp.sum() 7474*da0073e9SAndroid Build Coastguard Worker # Compute the gradients 7475*da0073e9SAndroid Build Coastguard Worker loss.backward() 7476*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(param.grad, 6 * param) 7477*da0073e9SAndroid Build Coastguard Worker 7478*da0073e9SAndroid Build Coastguard Worker def test_grad_fn_attr_bindings(self): 7479*da0073e9SAndroid Build Coastguard Worker # Check that the getter of each type returns what we want 7480*da0073e9SAndroid Build Coastguard Worker # See `gen_autograd_functions.py` for how the getters are generated 7481*da0073e9SAndroid Build Coastguard Worker # 7482*da0073e9SAndroid Build Coastguard Worker # This test is only meant to check if the codegen'd bindings work 7483*da0073e9SAndroid Build Coastguard Worker # Please help update this test if you update the names of any of the fields we check! 7484*da0073e9SAndroid Build Coastguard Worker # 7485*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, requires_grad=True) 7486*da0073e9SAndroid Build Coastguard Worker b = torch.zeros(1, requires_grad=True) 7487*da0073e9SAndroid Build Coastguard Worker out1 = torch.stack([a, b], dim=0) 7488*da0073e9SAndroid Build Coastguard Worker out2 = (a * 2) * b 7489*da0073e9SAndroid Build Coastguard Worker # TODO: I don't think we have a backward saving a list of tensors 7490*da0073e9SAndroid Build Coastguard Worker # at the moment. It used to be stack, but for no reason... 7491*da0073e9SAndroid Build Coastguard Worker # see discussion in #84993 7492*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(out.grad_fn._saved_tensors, (a, b)) # TensorList -> Tuple[Tensor] 7493*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out2.grad_fn._saved_self, a * 2) 7494*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out2.grad_fn._saved_self, torch.Tensor) 7495*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7496*da0073e9SAndroid Build Coastguard Worker out2.grad_fn._raw_saved_self, torch._C._autograd.SavedTensor 7497*da0073e9SAndroid Build Coastguard Worker ) 7498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out1.grad_fn._saved_dim, 0) # int64_t -> int 7499*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out1.grad_fn._saved_dim, int) 7500*da0073e9SAndroid Build Coastguard Worker 7501*da0073e9SAndroid Build Coastguard Worker out2.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 7502*da0073e9SAndroid Build Coastguard Worker 7503*da0073e9SAndroid Build Coastguard Worker out2.sum().backward() 7504*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7505*da0073e9SAndroid Build Coastguard Worker out2.grad_fn._saved_self 7506*da0073e9SAndroid Build Coastguard Worker # TODO: interestingly, this only happens if indexing into a list grad_fn._raw_saved_tensors[0], 7507*da0073e9SAndroid Build Coastguard Worker # not when using a saved tensor, see discussion in #84993 7508*da0073e9SAndroid Build Coastguard Worker # with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7509*da0073e9SAndroid Build Coastguard Worker # out2.grad_fn._raw_saved_self 7510*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out1.grad_fn._saved_dim, 0) 7511*da0073e9SAndroid Build Coastguard Worker 7512*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, 2, requires_grad=True) 7513*da0073e9SAndroid Build Coastguard Worker indices = torch.tensor([0, 1]) 7514*da0073e9SAndroid Build Coastguard Worker out = a[:, indices] 7515*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7516*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_indices, (None,
indices) 7517*da0073e9SAndroid Build Coastguard Worker ) # c10::List<std::optional<Tensor>> -> Tuple[Tensor?] 7518*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_indices[1], torch.Tensor) 7519*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7520*da0073e9SAndroid Build Coastguard Worker out.grad_fn._raw_saved_indices[1], torch._C._autograd.SavedTensor 7521*da0073e9SAndroid Build Coastguard Worker ) 7522*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7523*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_self_sym_sizes, a.shape 7524*da0073e9SAndroid Build Coastguard Worker ) # SymIntArrayRef -> Tuple[SymInt] 7525*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_self_sym_sizes[0], int) 7526*da0073e9SAndroid Build Coastguard Worker 7527*da0073e9SAndroid Build Coastguard Worker out.grad_fn._raw_saved_indices[1].register_hooks(lambda x: x, lambda x: x) 7528*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "None is forbidden"): 7529*da0073e9SAndroid Build Coastguard Worker out.grad_fn._raw_saved_indices[0].register_hooks(lambda x: x, lambda x: x) 7530*da0073e9SAndroid Build Coastguard Worker 7531*da0073e9SAndroid Build Coastguard Worker out = a.mean() 7532*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7533*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_self_sym_sizes, a.shape 7534*da0073e9SAndroid Build Coastguard Worker ) # IntArrayRef -> Tuple[int] 7535*da0073e9SAndroid Build Coastguard Worker 7536*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, 2, requires_grad=True) 7537*da0073e9SAndroid Build Coastguard Worker out = a * a 7538*da0073e9SAndroid Build Coastguard Worker out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 7539*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 7540*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "after it has been freed"): 7541*da0073e9SAndroid Build Coastguard Worker out.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x) 7542*da0073e9SAndroid Build Coastguard Worker 7543*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 1, 2, requires_grad=True) 7544*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.interpolate(a, 4, mode="linear") 7545*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7546*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_output_size, (4,) 7547*da0073e9SAndroid Build Coastguard Worker ) # std::optional<IntArrayRef> -> int[]? 7548*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_output_size[0], int) 7549*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_align_corners, False) # bool -> bool 7550*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_align_corners, bool) 7551*da0073e9SAndroid Build Coastguard Worker if hasattr(out.grad_fn, "_saved_scale_factors"): 7552*da0073e9SAndroid Build Coastguard Worker self.assertIsNone( 7553*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_scale_factors 7554*da0073e9SAndroid Build Coastguard Worker ) # std::optional<ArrayRef<double>> -> float[]? 
7555*da0073e9SAndroid Build Coastguard Worker else: 7556*da0073e9SAndroid Build Coastguard Worker self.assertIsNone( 7557*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_scales 7558*da0073e9SAndroid Build Coastguard Worker ) # std::optional<ArrayRef<double>> -> float[]? 7559*da0073e9SAndroid Build Coastguard Worker 7560*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 1, 3, 3, requires_grad=True) 7561*da0073e9SAndroid Build Coastguard Worker out = nn.Conv2d(1, 1, 3)(a) 7562*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7563*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_bias_sym_sizes_opt, (1,) 7564*da0073e9SAndroid Build Coastguard Worker ) # std::optional<SymIntArrayRef> -> SymInt[]? 7565*da0073e9SAndroid Build Coastguard Worker out = nn.Conv2d(1, 1, 3, bias=False)(a) 7566*da0073e9SAndroid Build Coastguard Worker # TODO: This is BAD! we converted a std::nullopt into a (0,) 7567*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_bias_sym_sizes_opt, (0,)) 7568*da0073e9SAndroid Build Coastguard Worker 7569*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 3, 3, requires_grad=True) 7570*da0073e9SAndroid Build Coastguard Worker out = torch.addbmm(a.squeeze(0), a, a) 7571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_0, 1) # int64_t 7572*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_batch1_sym_argsize_1, 3) # int64_t 7573*da0073e9SAndroid Build Coastguard Worker 7574*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 1, 3, 3, requires_grad=True) 7575*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.unfold(a, 3) 7576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_2, 3) # SymInt 7577*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_self_sym_argsize_minus_1, 3) # SymInt 7578*da0073e9SAndroid Build Coastguard Worker 7579*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 1, 2, requires_grad=True) 7580*da0073e9SAndroid Build Coastguard Worker out = torch.nn.functional.interpolate(a, scale_factor=0.5, mode="linear") 7581*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_scales, 0.5) 7582*da0073e9SAndroid Build Coastguard Worker 7583*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, 2, requires_grad=True) 7584*da0073e9SAndroid Build Coastguard Worker out = torch.pdist(a, p=1) 7585*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_p, 1.0) # double -> float 7586*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_p, float) 7587*da0073e9SAndroid Build Coastguard Worker 7588*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 1, 2, requires_grad=True) 7589*da0073e9SAndroid Build Coastguard Worker out = torch.logit(a, 1.0) 7590*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad_fn._saved_eps, 1.0) # c10:optional<double> -> float? 
7591*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_eps, float) 7592*da0073e9SAndroid Build Coastguard Worker out = torch.logit(a) 7593*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn._saved_eps) 7594*da0073e9SAndroid Build Coastguard Worker 7595*da0073e9SAndroid Build Coastguard Worker if torch._C.has_lapack: 7596*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, 1, requires_grad=True) 7597*da0073e9SAndroid Build Coastguard Worker q, r = torch.linalg.qr(a, mode="reduced") 7598*da0073e9SAndroid Build Coastguard Worker self.assertEqual(q.grad_fn._saved_mode, "reduced") # std::string -> str 7599*da0073e9SAndroid Build Coastguard Worker 7600*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.0], requires_grad=True) 7601*da0073e9SAndroid Build Coastguard Worker out = torch.div(a, 2.0, rounding_mode="trunc") 7602*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7603*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_rounding_mode, "trunc" 7604*da0073e9SAndroid Build Coastguard Worker ) # std::optional<std::string> -> str? 7605*da0073e9SAndroid Build Coastguard Worker out = torch.div(a, 2.0, rounding_mode=None) 7606*da0073e9SAndroid Build Coastguard Worker self.assertIsNone( 7607*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_rounding_mode 7608*da0073e9SAndroid Build Coastguard Worker ) # std::optional<std::string> -> str? 7609*da0073e9SAndroid Build Coastguard Worker 7610*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(5, requires_grad=True) 7611*da0073e9SAndroid Build Coastguard Worker out = torch.threshold(x, threshold=(1 + 0j), value=(1 + 0j)) 7612*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7613*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_threshold, complex 7614*da0073e9SAndroid Build Coastguard Worker ) # Scalar(complex double) -> complex 7615*da0073e9SAndroid Build Coastguard Worker cfloat = torch.tensor(1 + 0j, dtype=torch.complex64) 7616*da0073e9SAndroid Build Coastguard Worker out = torch.threshold(x, threshold=cfloat, value=(1 + 0j)) 7617*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7618*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_threshold, complex 7619*da0073e9SAndroid Build Coastguard Worker ) # Scalar(complex float) -> complex 7620*da0073e9SAndroid Build Coastguard Worker out = torch.threshold(x, threshold=1.0, value=1.0) 7621*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7622*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_threshold, float 7623*da0073e9SAndroid Build Coastguard Worker ) # Scalar(floating point) -> float 7624*da0073e9SAndroid Build Coastguard Worker out = torch.threshold(x, threshold=1, value=1) 7625*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7626*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_threshold, int 7627*da0073e9SAndroid Build Coastguard Worker ) # Scalar(integral) -> int 7628*da0073e9SAndroid Build Coastguard Worker out = torch.threshold(x, threshold=False, value=False) 7629*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance( 7630*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_threshold, bool 7631*da0073e9SAndroid Build Coastguard Worker ) # Scalar(bool) -> bool 7632*da0073e9SAndroid Build Coastguard Worker 7633*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, 2, requires_grad=True) 7634*da0073e9SAndroid Build Coastguard Worker out = a.as_strided((3,), (1,), 1) 
7635*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7636*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_storage_offset, 1 7637*da0073e9SAndroid Build Coastguard Worker ) # c10:optional<int64_t> -> int? 7638*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn._saved_storage_offset, int) 7639*da0073e9SAndroid Build Coastguard Worker out = a.as_strided((3,), (1,)) 7640*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn._saved_storage_offset) 7641*da0073e9SAndroid Build Coastguard Worker 7642*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, requires_grad=True) 7643*da0073e9SAndroid Build Coastguard Worker out = torch.tanh(a) 7644*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out, out.grad_fn._saved_result) # saved variable when output 7645*da0073e9SAndroid Build Coastguard Worker 7646*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 5, requires_grad=True) 7647*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([1, 0, 4]) 7648*da0073e9SAndroid Build Coastguard Worker loss = nn.NLLLoss() 7649*da0073e9SAndroid Build Coastguard Worker out = loss(a, b) 7650*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn._saved_weight) 7651*da0073e9SAndroid Build Coastguard Worker loss = nn.NLLLoss(weight=torch.ones((5,))) 7652*da0073e9SAndroid Build Coastguard Worker out = loss(a, b) 7653*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 7654*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_weight, torch.ones((5,)) 7655*da0073e9SAndroid Build Coastguard Worker ) # c10:optional<Tensor> -> Tensor? 7656*da0073e9SAndroid Build Coastguard Worker 7657*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 7658*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7659*da0073e9SAndroid Build Coastguard Worker out.grad_fn._saved_weight 7660*da0073e9SAndroid Build Coastguard Worker 7661*da0073e9SAndroid Build Coastguard Worker num_tensors = 3 7662*da0073e9SAndroid Build Coastguard Worker input_tensors = [ 7663*da0073e9SAndroid Build Coastguard Worker torch.ones(2, 2, requires_grad=True) for _ in range(num_tensors) 7664*da0073e9SAndroid Build Coastguard Worker ] 7665*da0073e9SAndroid Build Coastguard Worker scalars = [ 7666*da0073e9SAndroid Build Coastguard Worker 0.0 for _ in range(num_tensors) 7667*da0073e9SAndroid Build Coastguard Worker ] # ArrayRef<Scalar> -> Tuple[Scalar, ...] 
7668*da0073e9SAndroid Build Coastguard Worker results = torch._foreach_maximum(input_tensors, scalars) 7669*da0073e9SAndroid Build Coastguard Worker for t in results: 7670*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.grad_fn._saved_scalars, scalars) 7671*da0073e9SAndroid Build Coastguard Worker 7672*da0073e9SAndroid Build Coastguard Worker def test_cant_create_saved_tensors(self): 7673*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7674*da0073e9SAndroid Build Coastguard Worker RuntimeError, 7675*da0073e9SAndroid Build Coastguard Worker "Trying to create a SavedTensor object from Python is forbidden", 7676*da0073e9SAndroid Build Coastguard Worker ): 7677*da0073e9SAndroid Build Coastguard Worker torch.autograd.SavedTensor() 7678*da0073e9SAndroid Build Coastguard Worker 7679*da0073e9SAndroid Build Coastguard Worker def test_custom_function_saved_tensors(self): 7680*da0073e9SAndroid Build Coastguard Worker def getFn(save=True): 7681*da0073e9SAndroid Build Coastguard Worker class MyFn(Function): 7682*da0073e9SAndroid Build Coastguard Worker @staticmethod 7683*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 7684*da0073e9SAndroid Build Coastguard Worker if save: 7685*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x, None) 7686*da0073e9SAndroid Build Coastguard Worker return x 7687*da0073e9SAndroid Build Coastguard Worker 7688*da0073e9SAndroid Build Coastguard Worker @staticmethod 7689*da0073e9SAndroid Build Coastguard Worker def backward(ctx, g): 7690*da0073e9SAndroid Build Coastguard Worker return g 7691*da0073e9SAndroid Build Coastguard Worker 7692*da0073e9SAndroid Build Coastguard Worker return MyFn 7693*da0073e9SAndroid Build Coastguard Worker 7694*da0073e9SAndroid Build Coastguard Worker a = torch.randn(5, requires_grad=True) 7695*da0073e9SAndroid Build Coastguard Worker 7696*da0073e9SAndroid Build Coastguard Worker y = getFn(True).apply(a) 7697*da0073e9SAndroid Build Coastguard Worker 7698*da0073e9SAndroid Build Coastguard Worker self.assertEqual((a, None), y.grad_fn.saved_tensors) 7699*da0073e9SAndroid Build Coastguard Worker saved = y.grad_fn._raw_saved_tensors 7700*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(saved[0], torch._C._autograd.SavedTensor) 7701*da0073e9SAndroid Build Coastguard Worker # We can't tell the underlying tensor is None without unpacking it 7702*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(saved[1], torch._C._autograd.SavedTensor) 7703*da0073e9SAndroid Build Coastguard Worker 7704*da0073e9SAndroid Build Coastguard Worker # We catch that error when the user calls register_hooks on it 7705*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "None is forbidden"): 7706*da0073e9SAndroid Build Coastguard Worker saved[1].register_hooks(lambda x: x, lambda x: x) 7707*da0073e9SAndroid Build Coastguard Worker 7708*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "incompatible function arguments"): 7709*da0073e9SAndroid Build Coastguard Worker saved[0].register_hooks(lambda x: x) 7710*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "incompatible function arguments"): 7711*da0073e9SAndroid Build Coastguard Worker saved[0].register_hooks(1, 1) 7712*da0073e9SAndroid Build Coastguard Worker saved[0].register_hooks(lambda x: x, lambda x: x) 7713*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "already been set"): 7714*da0073e9SAndroid Build Coastguard 
Worker saved[0].register_hooks(lambda x: x, lambda x: x) 7715*da0073e9SAndroid Build Coastguard Worker y.sum().backward() 7716*da0073e9SAndroid Build Coastguard Worker 7717*da0073e9SAndroid Build Coastguard Worker # Using a reference to the SavedTensor object after the 7718*da0073e9SAndroid Build Coastguard Worker # saved variables have been released can lead to undefined behavior 7719*da0073e9SAndroid Build Coastguard Worker del saved 7720*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7721*da0073e9SAndroid Build Coastguard Worker y.grad_fn._raw_saved_tensors 7722*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "after they have already been freed"): 7723*da0073e9SAndroid Build Coastguard Worker y.grad_fn.saved_tensors 7724*da0073e9SAndroid Build Coastguard Worker 7725*da0073e9SAndroid Build Coastguard Worker y = getFn(False).apply(a) 7726*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad_fn.saved_tensors, ()) 7727*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad_fn._raw_saved_tensors, ()) 7728*da0073e9SAndroid Build Coastguard Worker 7729*da0073e9SAndroid Build Coastguard Worker def test_autograd_node_isinstance(self): 7730*da0073e9SAndroid Build Coastguard Worker # Node is a "virtual" base class of codegen'd nodes. This means that 7731*da0073e9SAndroid Build Coastguard Worker # isinstance and issubclass are overridden, but mro is unchanged 7732*da0073e9SAndroid Build Coastguard Worker Node = torch.autograd.graph.Node 7733*da0073e9SAndroid Build Coastguard Worker 7734*da0073e9SAndroid Build Coastguard Worker a = torch.rand(3, 3, requires_grad=True) 7735*da0073e9SAndroid Build Coastguard Worker b = a.exp() 7736*da0073e9SAndroid Build Coastguard Worker 7737*da0073e9SAndroid Build Coastguard Worker # Some nodes have codegened registrations to the torch._C._function module 7738*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(b.grad_fn, Node) 7739*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(type(b.grad_fn), Node)) 7740*da0073e9SAndroid Build Coastguard Worker self.assertTrue(Node not in type(b.grad_fn).mro()) 7741*da0073e9SAndroid Build Coastguard Worker 7742*da0073e9SAndroid Build Coastguard Worker # Other nodes have manual registrations to the torch._C._function module 7743*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(torch._C._functions.AccumulateGrad, Node) 7744*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(torch._C._functions.AccumulateGrad, Node)) 7745*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(b.grad_fn.next_functions[0][0], Node) 7746*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(torch._C._functions.DelayedError, Node)) 7747*da0073e9SAndroid Build Coastguard Worker 7748*da0073e9SAndroid Build Coastguard Worker # Special cases 7749*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(None, Node) 7750*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(1, Node) 7751*da0073e9SAndroid Build Coastguard Worker self.assertNotIsInstance(Node, Node) 7752*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(Node, Node)) 7753*da0073e9SAndroid Build Coastguard Worker 7754*da0073e9SAndroid Build Coastguard Worker # Custom function case 7755*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(torch.autograd.function.BackwardCFunction, Node)) 7756*da0073e9SAndroid Build Coastguard 
Worker 7757*da0073e9SAndroid Build Coastguard Worker class Func(torch.autograd.Function): 7758*da0073e9SAndroid Build Coastguard Worker @staticmethod 7759*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 7760*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(ctx, Node) 7761*da0073e9SAndroid Build Coastguard Worker return x 7762*da0073e9SAndroid Build Coastguard Worker 7763*da0073e9SAndroid Build Coastguard Worker @staticmethod 7764*da0073e9SAndroid Build Coastguard Worker def backward(ctx, x): 7765*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(ctx, Node) 7766*da0073e9SAndroid Build Coastguard Worker return x 7767*da0073e9SAndroid Build Coastguard Worker 7768*da0073e9SAndroid Build Coastguard Worker out = Func.apply(a) 7769*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(out.grad_fn, Node) 7770*da0073e9SAndroid Build Coastguard Worker self.assertTrue(issubclass(type(out.grad_fn), Node)) 7771*da0073e9SAndroid Build Coastguard Worker self.assertTrue(Node not in type(out.grad_fn).mro()) 7772*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 7773*da0073e9SAndroid Build Coastguard Worker 7774*da0073e9SAndroid Build Coastguard Worker def test_autograd_views_codegen(self): 7775*da0073e9SAndroid Build Coastguard Worker # This is not necessarily the absolute correct behavior, but this is the current 7776*da0073e9SAndroid Build Coastguard Worker # one. This test is here to make sure that any change to this behavior is detected 7777*da0073e9SAndroid Build Coastguard Worker # and not silent. The TODOs below mark the places with unexpected behavior. 7778*da0073e9SAndroid Build Coastguard Worker # Note that any change in these tests will be BC-breaking and should be done carefully. 7779*da0073e9SAndroid Build Coastguard Worker 7780*da0073e9SAndroid Build Coastguard Worker # This test checks the behavior of two codegen functions (view_as and unbind) 7781*da0073e9SAndroid Build Coastguard Worker # with respect to view tracking and inplace operations on the output. 7782*da0073e9SAndroid Build Coastguard Worker 7783*da0073e9SAndroid Build Coastguard Worker def run_test(grad_mode, requires_grad, is_view, should_raise_tuple): 7784*da0073e9SAndroid Build Coastguard Worker def maybe_check_raise(fn, should_raise): 7785*da0073e9SAndroid Build Coastguard Worker self.assertTrue(should_raise is None or isinstance(should_raise, str)) 7786*da0073e9SAndroid Build Coastguard Worker if should_raise is not None: 7787*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, should_raise): 7788*da0073e9SAndroid Build Coastguard Worker fn() 7789*da0073e9SAndroid Build Coastguard Worker else: 7790*da0073e9SAndroid Build Coastguard Worker fn() 7791*da0073e9SAndroid Build Coastguard Worker 7792*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2, requires_grad=requires_grad).clone() 7793*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(grad_mode): 7794*da0073e9SAndroid Build Coastguard Worker out = inp.view_as(inp) 7795*da0073e9SAndroid Build Coastguard Worker # Are they differentiable views? 7796*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out._is_view() == is_view) 7797*da0073e9SAndroid Build Coastguard Worker # Are inplace allowed?
7798*da0073e9SAndroid Build Coastguard Worker maybe_check_raise(lambda: out.add_(1), should_raise_tuple[0]) 7799*da0073e9SAndroid Build Coastguard Worker 7800*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2, requires_grad=requires_grad).clone() 7801*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(grad_mode): 7802*da0073e9SAndroid Build Coastguard Worker out = inp.unbind() 7803*da0073e9SAndroid Build Coastguard Worker # Are they differentiable views? 7804*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out[0]._is_view() == is_view) 7805*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out[1]._is_view() == is_view) 7806*da0073e9SAndroid Build Coastguard Worker # Are inplace allowed? 7807*da0073e9SAndroid Build Coastguard Worker maybe_check_raise(lambda: out[0].add_(1), should_raise_tuple[1]) 7808*da0073e9SAndroid Build Coastguard Worker maybe_check_raise(lambda: out[1].add_(1), should_raise_tuple[2]) 7809*da0073e9SAndroid Build Coastguard Worker 7810*da0073e9SAndroid Build Coastguard Worker # should_raise contains None if it should not raise 7811*da0073e9SAndroid Build Coastguard Worker # should_raise contains a string of the error if it should raise 7812*da0073e9SAndroid Build Coastguard Worker # The 3 elements are for view_as, first output of unbind and second output of unbind 7813*da0073e9SAndroid Build Coastguard Worker run_test( 7814*da0073e9SAndroid Build Coastguard Worker grad_mode=True, 7815*da0073e9SAndroid Build Coastguard Worker requires_grad=False, 7816*da0073e9SAndroid Build Coastguard Worker is_view=True, 7817*da0073e9SAndroid Build Coastguard Worker should_raise_tuple=(None, None, None), 7818*da0073e9SAndroid Build Coastguard Worker ) 7819*da0073e9SAndroid Build Coastguard Worker inp_change_err = ( 7820*da0073e9SAndroid Build Coastguard Worker "Output {} of UnbindBackward0 is a view and is being modified inplace." 
7821*da0073e9SAndroid Build Coastguard Worker ) 7822*da0073e9SAndroid Build Coastguard Worker run_test( 7823*da0073e9SAndroid Build Coastguard Worker grad_mode=True, 7824*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 7825*da0073e9SAndroid Build Coastguard Worker is_view=True, 7826*da0073e9SAndroid Build Coastguard Worker should_raise_tuple=( 7827*da0073e9SAndroid Build Coastguard Worker None, 7828*da0073e9SAndroid Build Coastguard Worker inp_change_err.format("0"), 7829*da0073e9SAndroid Build Coastguard Worker inp_change_err.format("1"), 7830*da0073e9SAndroid Build Coastguard Worker ), 7831*da0073e9SAndroid Build Coastguard Worker ) 7832*da0073e9SAndroid Build Coastguard Worker leaf_grad_err = ( 7833*da0073e9SAndroid Build Coastguard Worker "A view was created in no_grad mode and is being modified inplace" 7834*da0073e9SAndroid Build Coastguard Worker ) 7835*da0073e9SAndroid Build Coastguard Worker run_test( 7836*da0073e9SAndroid Build Coastguard Worker grad_mode=False, 7837*da0073e9SAndroid Build Coastguard Worker requires_grad=True, 7838*da0073e9SAndroid Build Coastguard Worker is_view=True, 7839*da0073e9SAndroid Build Coastguard Worker should_raise_tuple=(leaf_grad_err, leaf_grad_err, leaf_grad_err), 7840*da0073e9SAndroid Build Coastguard Worker ) 7841*da0073e9SAndroid Build Coastguard Worker run_test( 7842*da0073e9SAndroid Build Coastguard Worker grad_mode=False, 7843*da0073e9SAndroid Build Coastguard Worker requires_grad=False, 7844*da0073e9SAndroid Build Coastguard Worker is_view=True, 7845*da0073e9SAndroid Build Coastguard Worker should_raise_tuple=(None, None, None), 7846*da0073e9SAndroid Build Coastguard Worker ) 7847*da0073e9SAndroid Build Coastguard Worker 7848*da0073e9SAndroid Build Coastguard Worker def test_inplace_not_requires_grad(self): 7849*da0073e9SAndroid Build Coastguard Worker class MyFn(torch.autograd.Function): 7850*da0073e9SAndroid Build Coastguard Worker @staticmethod 7851*da0073e9SAndroid Build Coastguard Worker def forward(ctx, inp): 7852*da0073e9SAndroid Build Coastguard Worker return inp.view_as(inp) 7853*da0073e9SAndroid Build Coastguard Worker 7854*da0073e9SAndroid Build Coastguard Worker @staticmethod 7855*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 7856*da0073e9SAndroid Build Coastguard Worker return grad 7857*da0073e9SAndroid Build Coastguard Worker 7858*da0073e9SAndroid Build Coastguard Worker # Original Tensor does not require grad 7859*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, 2) 7860*da0073e9SAndroid Build Coastguard Worker 7861*da0073e9SAndroid Build Coastguard Worker # Tensor being written does require grad 7862*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1, requires_grad=True) 7863*da0073e9SAndroid Build Coastguard Worker 7864*da0073e9SAndroid Build Coastguard Worker # Take an invalid view on 'a' that should raise an error (warns during deprecation) 7865*da0073e9SAndroid Build Coastguard Worker view_a = MyFn.apply(a) 7866*da0073e9SAndroid Build Coastguard Worker 7867*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7868*da0073e9SAndroid Build Coastguard Worker RuntimeError, "This view was created inside a custom Function" 7869*da0073e9SAndroid Build Coastguard Worker ): 7870*da0073e9SAndroid Build Coastguard Worker view_a += b 7871*da0073e9SAndroid Build Coastguard Worker 7872*da0073e9SAndroid Build Coastguard Worker # Extra test for copy_ that is a manual implementation and could be easily 7873*da0073e9SAndroid Build Coastguard Worker # forgotten when the 
codegen is updated (warns during deprecation) 7874*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, 2) 7875*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1, requires_grad=True) 7876*da0073e9SAndroid Build Coastguard Worker view_a = MyFn.apply(a) 7877*da0073e9SAndroid Build Coastguard Worker 7878*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7879*da0073e9SAndroid Build Coastguard Worker RuntimeError, "This view was created inside a custom Function" 7880*da0073e9SAndroid Build Coastguard Worker ): 7881*da0073e9SAndroid Build Coastguard Worker view_a.copy_(b) 7882*da0073e9SAndroid Build Coastguard Worker 7883*da0073e9SAndroid Build Coastguard Worker # Functions that should throw must properly throw 7884*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, 2) 7885*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1, requires_grad=True) 7886*da0073e9SAndroid Build Coastguard Worker view_a = a.unbind()[0] 7887*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 7888*da0073e9SAndroid Build Coastguard Worker RuntimeError, 7889*da0073e9SAndroid Build Coastguard Worker "This view is the output of a function that returns " "multiple views.", 7890*da0073e9SAndroid Build Coastguard Worker ): 7891*da0073e9SAndroid Build Coastguard Worker view_a.copy_(b) 7892*da0073e9SAndroid Build Coastguard Worker 7893*da0073e9SAndroid Build Coastguard Worker # Sanity check that views that should work still work 7894*da0073e9SAndroid Build Coastguard Worker a = torch.rand(1, 2) 7895*da0073e9SAndroid Build Coastguard Worker b = torch.rand(1, requires_grad=True) 7896*da0073e9SAndroid Build Coastguard Worker a.select(1, 0).copy_(b) 7897*da0073e9SAndroid Build Coastguard Worker 7898*da0073e9SAndroid Build Coastguard Worker def _do_test_autograd_simple_views_python(self, dtype): 7899*da0073e9SAndroid Build Coastguard Worker # This is not necessarily the absolute correct behavior, but this is the current 7900*da0073e9SAndroid Build Coastguard Worker # one. This test is here to make sure that any change to this behavior is detected 7901*da0073e9SAndroid Build Coastguard Worker # and not silent. The TODOs below mark the places with unexpected behavior. 7902*da0073e9SAndroid Build Coastguard Worker # Note that any change in these tests will be BC-breaking and should be done carefully. 7903*da0073e9SAndroid Build Coastguard Worker 7904*da0073e9SAndroid Build Coastguard Worker # This checks the autograd.Function behavior when we return one or multiple outputs 7905*da0073e9SAndroid Build Coastguard Worker # while one of these is an input, a view of an input or of a temporary tensor.
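# When the returned output is a view (make_view=True or the ViewOfTemp case),
# an in-place update on it is expected to raise rather than silently compute
# wrong gradients, and the custom backward is then never invoked (see the
# expected_called bookkeeping below).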
7906*da0073e9SAndroid Build Coastguard Worker 7907*da0073e9SAndroid Build Coastguard Worker # This indicator is used to track how many times the backward function was called 7908*da0073e9SAndroid Build Coastguard Worker bw_called = [0] 7909*da0073e9SAndroid Build Coastguard Worker # This indicator is used to check if the argument `ga` contains non-zero values 7910*da0073e9SAndroid Build Coastguard Worker ga_nz = [False] 7911*da0073e9SAndroid Build Coastguard Worker 7912*da0073e9SAndroid Build Coastguard Worker class IdOneOutput(Function): 7913*da0073e9SAndroid Build Coastguard Worker @staticmethod 7914*da0073e9SAndroid Build Coastguard Worker def forward(ctx, a, b, make_view): 7915*da0073e9SAndroid Build Coastguard Worker if make_view: 7916*da0073e9SAndroid Build Coastguard Worker a = a.narrow(0, 0, 2) 7917*da0073e9SAndroid Build Coastguard Worker else: 7918*da0073e9SAndroid Build Coastguard Worker a = a.clone() 7919*da0073e9SAndroid Build Coastguard Worker return a 7920*da0073e9SAndroid Build Coastguard Worker 7921*da0073e9SAndroid Build Coastguard Worker @staticmethod 7922*da0073e9SAndroid Build Coastguard Worker def backward(ctx, ga): 7923*da0073e9SAndroid Build Coastguard Worker bw_called[0] += 1 7924*da0073e9SAndroid Build Coastguard Worker return ga, None, None 7925*da0073e9SAndroid Build Coastguard Worker 7926*da0073e9SAndroid Build Coastguard Worker class IdTwoOutput(Function): 7927*da0073e9SAndroid Build Coastguard Worker @staticmethod 7928*da0073e9SAndroid Build Coastguard Worker def forward(ctx, a, b, make_view): 7929*da0073e9SAndroid Build Coastguard Worker if make_view: 7930*da0073e9SAndroid Build Coastguard Worker a = a.narrow(0, 0, 2) 7931*da0073e9SAndroid Build Coastguard Worker else: 7932*da0073e9SAndroid Build Coastguard Worker a = a.clone() 7933*da0073e9SAndroid Build Coastguard Worker return a, a + b 7934*da0073e9SAndroid Build Coastguard Worker 7935*da0073e9SAndroid Build Coastguard Worker @staticmethod 7936*da0073e9SAndroid Build Coastguard Worker def backward(ctx, ga, gab): 7937*da0073e9SAndroid Build Coastguard Worker bw_called[0] += 1 7938*da0073e9SAndroid Build Coastguard Worker if ga.eq(0).all(): 7939*da0073e9SAndroid Build Coastguard Worker ga_nz[0] = False 7940*da0073e9SAndroid Build Coastguard Worker else: 7941*da0073e9SAndroid Build Coastguard Worker ga_nz[0] = True 7942*da0073e9SAndroid Build Coastguard Worker return ga + gab, gab, None 7943*da0073e9SAndroid Build Coastguard Worker 7944*da0073e9SAndroid Build Coastguard Worker class ViewOfTemp(Function): 7945*da0073e9SAndroid Build Coastguard Worker @staticmethod 7946*da0073e9SAndroid Build Coastguard Worker def forward(ctx, a, make_view): 7947*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(a) 7948*da0073e9SAndroid Build Coastguard Worker if make_view: 7949*da0073e9SAndroid Build Coastguard Worker a = a.narrow(0, 0, 2) 7950*da0073e9SAndroid Build Coastguard Worker else: 7951*da0073e9SAndroid Build Coastguard Worker a = a.clone() 7952*da0073e9SAndroid Build Coastguard Worker b = a.clone() 7953*da0073e9SAndroid Build Coastguard Worker return b.select(0, 0) 7954*da0073e9SAndroid Build Coastguard Worker 7955*da0073e9SAndroid Build Coastguard Worker @staticmethod 7956*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 7957*da0073e9SAndroid Build Coastguard Worker bw_called[0] += 1 7958*da0073e9SAndroid Build Coastguard Worker (a,) = ctx.saved_tensors 7959*da0073e9SAndroid Build Coastguard Worker res = torch.zeros_like(a) 7960*da0073e9SAndroid Build Coastguard Worker res.select(0, 
0).copy_(grad) 7961*da0073e9SAndroid Build Coastguard Worker return res, None 7962*da0073e9SAndroid Build Coastguard Worker 7963*da0073e9SAndroid Build Coastguard Worker fn_id_to_inplace_on_view_err_msg = { 7964*da0073e9SAndroid Build Coastguard Worker "one_output": ( 7965*da0073e9SAndroid Build Coastguard Worker "Output 0 of IdOneOutputBackward is a view and is being " 7966*da0073e9SAndroid Build Coastguard Worker "modified inplace. This view was created inside a custom Function" 7967*da0073e9SAndroid Build Coastguard Worker ), 7968*da0073e9SAndroid Build Coastguard Worker "two_output": ( 7969*da0073e9SAndroid Build Coastguard Worker "Output 0 of IdTwoOutputBackward is a view and is being modified inplace." 7970*da0073e9SAndroid Build Coastguard Worker " This view is the output of a function that returns multiple views." 7971*da0073e9SAndroid Build Coastguard Worker ), 7972*da0073e9SAndroid Build Coastguard Worker "view_of_temp": ( 7973*da0073e9SAndroid Build Coastguard Worker "Output 0 of ViewOfTempBackward is a view and is being " 7974*da0073e9SAndroid Build Coastguard Worker "modified inplace. This view was created inside a custom Function" 7975*da0073e9SAndroid Build Coastguard Worker ), 7976*da0073e9SAndroid Build Coastguard Worker } 7977*da0073e9SAndroid Build Coastguard Worker 7978*da0073e9SAndroid Build Coastguard Worker for fn_id in ["one_output", "two_output", "view_of_temp"]: 7979*da0073e9SAndroid Build Coastguard Worker for inplace in [True, False]: 7980*da0073e9SAndroid Build Coastguard Worker for make_view in [True, False]: 7981*da0073e9SAndroid Build Coastguard Worker # Used for special casing the tests below 7982*da0073e9SAndroid Build Coastguard Worker output_is_a_view = make_view or fn_id == "view_of_temp" 7983*da0073e9SAndroid Build Coastguard Worker 7984*da0073e9SAndroid Build Coastguard Worker def fn(a, b): 7985*da0073e9SAndroid Build Coastguard Worker # never modify a, b inplace for gradcheck 7986*da0073e9SAndroid Build Coastguard Worker a = a.clone() 7987*da0073e9SAndroid Build Coastguard Worker b = b.clone() 7988*da0073e9SAndroid Build Coastguard Worker if fn_id == "two_output": 7989*da0073e9SAndroid Build Coastguard Worker tmp1, tmp2 = IdTwoOutput.apply(a, b, make_view) 7990*da0073e9SAndroid Build Coastguard Worker if inplace: 7991*da0073e9SAndroid Build Coastguard Worker tmp1 += 3 7992*da0073e9SAndroid Build Coastguard Worker tmp2 += 3 7993*da0073e9SAndroid Build Coastguard Worker else: 7994*da0073e9SAndroid Build Coastguard Worker tmp1 = tmp1 + 3 7995*da0073e9SAndroid Build Coastguard Worker tmp2 = tmp2 + 3 7996*da0073e9SAndroid Build Coastguard Worker tmp = tmp1 * tmp2 7997*da0073e9SAndroid Build Coastguard Worker else: 7998*da0073e9SAndroid Build Coastguard Worker if fn_id == "one_output": 7999*da0073e9SAndroid Build Coastguard Worker tmp = IdOneOutput.apply(a, b, make_view) 8000*da0073e9SAndroid Build Coastguard Worker else: 8001*da0073e9SAndroid Build Coastguard Worker tmp = ViewOfTemp.apply(a + b, make_view) 8002*da0073e9SAndroid Build Coastguard Worker if inplace: 8003*da0073e9SAndroid Build Coastguard Worker tmp += 3 8004*da0073e9SAndroid Build Coastguard Worker else: 8005*da0073e9SAndroid Build Coastguard Worker tmp = tmp + 3 8006*da0073e9SAndroid Build Coastguard Worker 8007*da0073e9SAndroid Build Coastguard Worker return tmp.sum() 8008*da0073e9SAndroid Build Coastguard Worker 8009*da0073e9SAndroid Build Coastguard Worker a = torch.ones(2, dtype=dtype, requires_grad=True) 8010*da0073e9SAndroid Build Coastguard Worker b = torch.ones(2, dtype=dtype,
requires_grad=True) 8011*da0073e9SAndroid Build Coastguard Worker 8012*da0073e9SAndroid Build Coastguard Worker err_msg = fn_id_to_inplace_on_view_err_msg[fn_id] 8013*da0073e9SAndroid Build Coastguard Worker 8014*da0073e9SAndroid Build Coastguard Worker if not inplace or not output_is_a_view: 8015*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (a, b), check_batched_grad=False) 8016*da0073e9SAndroid Build Coastguard Worker 8017*da0073e9SAndroid Build Coastguard Worker # Was the custom backward called properly 8018*da0073e9SAndroid Build Coastguard Worker bw_called[0] = 0 8019*da0073e9SAndroid Build Coastguard Worker ga_nz[0] = True # For the case where the backward is called 8020*da0073e9SAndroid Build Coastguard Worker 8021*da0073e9SAndroid Build Coastguard Worker if inplace and output_is_a_view: 8022*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 8023*da0073e9SAndroid Build Coastguard Worker fn(a, b) 8024*da0073e9SAndroid Build Coastguard Worker else: 8025*da0073e9SAndroid Build Coastguard Worker fn(a, b).abs().backward() 8026*da0073e9SAndroid Build Coastguard Worker 8027*da0073e9SAndroid Build Coastguard Worker expected_called = 1 8028*da0073e9SAndroid Build Coastguard Worker expected_ga_nz = True 8029*da0073e9SAndroid Build Coastguard Worker 8030*da0073e9SAndroid Build Coastguard Worker if output_is_a_view and inplace: 8031*da0073e9SAndroid Build Coastguard Worker expected_called = 0 8032*da0073e9SAndroid Build Coastguard Worker 8033*da0073e9SAndroid Build Coastguard Worker self.assertTrue(bw_called[0] == expected_called) 8034*da0073e9SAndroid Build Coastguard Worker self.assertTrue(ga_nz[0] == expected_ga_nz) 8035*da0073e9SAndroid Build Coastguard Worker 8036*da0073e9SAndroid Build Coastguard Worker def test_autograd_simple_views_python(self): 8037*da0073e9SAndroid Build Coastguard Worker self._do_test_autograd_simple_views_python(torch.double) 8038*da0073e9SAndroid Build Coastguard Worker self._do_test_autograd_simple_views_python(torch.cdouble) 8039*da0073e9SAndroid Build Coastguard Worker 8040*da0073e9SAndroid Build Coastguard Worker def test_autograd_inplace_views_creation_meta(self): 8041*da0073e9SAndroid Build Coastguard Worker # Tests creation_meta properly handled for inplace views 8042*da0073e9SAndroid Build Coastguard Worker 8043*da0073e9SAndroid Build Coastguard Worker class Func(torch.autograd.Function): 8044*da0073e9SAndroid Build Coastguard Worker @staticmethod 8045*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 8046*da0073e9SAndroid Build Coastguard Worker return x.view_as(x) 8047*da0073e9SAndroid Build Coastguard Worker 8048*da0073e9SAndroid Build Coastguard Worker @staticmethod 8049*da0073e9SAndroid Build Coastguard Worker def backward(ctx, x): 8050*da0073e9SAndroid Build Coastguard Worker return x 8051*da0073e9SAndroid Build Coastguard Worker 8052*da0073e9SAndroid Build Coastguard Worker view_custom = Func.apply 8053*da0073e9SAndroid Build Coastguard Worker 8054*da0073e9SAndroid Build Coastguard Worker def run_test( 8055*da0073e9SAndroid Build Coastguard Worker fn, fn_type, grad_mode_view, grad_mode_iview, requires_grad, error1, error2 8056*da0073e9SAndroid Build Coastguard Worker ): 8057*da0073e9SAndroid Build Coastguard Worker # This test checks the behavior of inplace-view functions when 8058*da0073e9SAndroid Build Coastguard Worker # the views are created in grad mode or not 8059*da0073e9SAndroid Build Coastguard Worker base = torch.rand(2, 3, requires_grad=requires_grad).clone() 
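# The two grad-mode flags are applied independently: grad_mode_view controls
# the mode in which the view is created (step 1) and grad_mode_iview controls
# the mode in which the in-place view op runs (step 2); error1/error2 encode
# which combinations are expected to raise.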
8060*da0073e9SAndroid Build Coastguard Worker # 1. Create a view with `grad_mode=grad_mode_view` 8061*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(grad_mode_view): 8062*da0073e9SAndroid Build Coastguard Worker if fn_type == "multi_view": 8063*da0073e9SAndroid Build Coastguard Worker inp = base.unbind()[0] 8064*da0073e9SAndroid Build Coastguard Worker elif fn_type == "custom": 8065*da0073e9SAndroid Build Coastguard Worker inp = view_custom(base) 8066*da0073e9SAndroid Build Coastguard Worker else: 8067*da0073e9SAndroid Build Coastguard Worker inp = base.view_as(base) 8068*da0073e9SAndroid Build Coastguard Worker 8069*da0073e9SAndroid Build Coastguard Worker # 2. Perform inplace view with `grad_mode=grad_mode_iview` 8070*da0073e9SAndroid Build Coastguard Worker with torch.set_grad_enabled(grad_mode_iview): 8071*da0073e9SAndroid Build Coastguard Worker if error1 is not None: 8072*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error1): 8073*da0073e9SAndroid Build Coastguard Worker fn(inp) 8074*da0073e9SAndroid Build Coastguard Worker return 8075*da0073e9SAndroid Build Coastguard Worker else: 8076*da0073e9SAndroid Build Coastguard Worker # If error is None, check that runs without error 8077*da0073e9SAndroid Build Coastguard Worker fn(inp) 8078*da0073e9SAndroid Build Coastguard Worker # 3. Do inplace on the (new) view 8079*da0073e9SAndroid Build Coastguard Worker if error2 is not None: 8080*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error2): 8081*da0073e9SAndroid Build Coastguard Worker inp.add_(1) 8082*da0073e9SAndroid Build Coastguard Worker else: 8083*da0073e9SAndroid Build Coastguard Worker # If error is None, check that runs without error 8084*da0073e9SAndroid Build Coastguard Worker inp.add_(1) 8085*da0073e9SAndroid Build Coastguard Worker 8086*da0073e9SAndroid Build Coastguard Worker no_grad_err = "A view was created in no_grad mode" 8087*da0073e9SAndroid Build Coastguard Worker multi_view_err = "function that returns multiple views" 8088*da0073e9SAndroid Build Coastguard Worker custom_err = "view was created inside a custom Function" 8089*da0073e9SAndroid Build Coastguard Worker 8090*da0073e9SAndroid Build Coastguard Worker def run_tests(fn): 8091*da0073e9SAndroid Build Coastguard Worker for fn_type in ("normal", "multi_view", "custom"): 8092*da0073e9SAndroid Build Coastguard Worker for grad_mode_view in (True, False): 8093*da0073e9SAndroid Build Coastguard Worker for grad_mode_iview in (True, False): 8094*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 8095*da0073e9SAndroid Build Coastguard Worker error1 = None # expected error when we do inplace_view on original view 8096*da0073e9SAndroid Build Coastguard Worker error2 = None # expected error when we do inplace on the resulting view 8097*da0073e9SAndroid Build Coastguard Worker 8098*da0073e9SAndroid Build Coastguard Worker if requires_grad: 8099*da0073e9SAndroid Build Coastguard Worker if not grad_mode_view and grad_mode_iview: 8100*da0073e9SAndroid Build Coastguard Worker error1 = no_grad_err 8101*da0073e9SAndroid Build Coastguard Worker if not grad_mode_view and not grad_mode_iview: 8102*da0073e9SAndroid Build Coastguard Worker error2 = no_grad_err 8103*da0073e9SAndroid Build Coastguard Worker 8104*da0073e9SAndroid Build Coastguard Worker if fn_type == "multi_view": 8105*da0073e9SAndroid Build Coastguard Worker if grad_mode_view and grad_mode_iview: 8106*da0073e9SAndroid Build Coastguard Worker error1 = 
                                if grad_mode_view and not grad_mode_iview:
                                    error2 = multi_view_err

                            if fn_type == "custom":
                                if grad_mode_view and grad_mode_iview:
                                    error1 = custom_err
                                if grad_mode_view and not grad_mode_iview:
                                    error2 = custom_err

                            run_test(
                                fn,
                                fn_type,
                                grad_mode_view,
                                grad_mode_iview,
                                requires_grad,
                                error1,
                                error2,
                            )

        # This list was created by logging gen_inplace_or_view_type.py
        # detach_ is excluded for this test because it cannot be applied to
        # views and thus does not return a view
        run_tests(lambda v: v.as_strided_((1, 0), (2, 2)))
        run_tests(lambda v: v.transpose_(0, 0))
        run_tests(lambda v: v.t_())
        run_tests(lambda v: v.squeeze_(0))
        run_tests(lambda v: v.unsqueeze_(0))
        run_tests(lambda v: v.swapdims_(0, 0))
        run_tests(lambda v: v.swapaxes_(0, 0))

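    # Illustrative sketch (not exercised by this suite; the helper name is
    # arbitrary): the three creation_meta cases covered above can be reproduced
    # directly. Each in-place call is left commented out because, with grad mode
    # enabled on a base that requires grad, it would raise the quoted error.
    def _sketch_inplace_view_creation_meta(self):
        base = torch.rand(2, 3, requires_grad=True).clone()

        # View created in no_grad mode
        with torch.no_grad():
            v = base.view_as(base)
        # v.add_(1)  # would raise: "A view was created in no_grad mode ..."

        # View produced by a multi-output view op
        v = base.unbind()[0]
        # v.add_(1)  # would raise: "... function that returns multiple views ..."

        # View returned by a custom Function
        class Identity(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.view_as(x)

            @staticmethod
            def backward(ctx, g):
                return g

        v = Identity.apply(base)
        # v.add_(1)  # would raise: "... view was created inside a custom Function ..."
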
    def test_autograd_print_tensor(self):
        a = torch.ones(1, requires_grad=True)
        a_clone = a.clone()
        self.assertEqual(repr(a), "tensor([1.], requires_grad=True)")
        self.assertEqual(repr(a_clone), "tensor([1.], grad_fn=<CloneBackward0>)")

        with torch.no_grad():
            b = a[:]
            b *= 2

        # Special handling for printing a view created in no_grad mode and
        # modified in-place in no_grad mode.
        self.assertEqual(repr(b), "tensor([2.], grad_fn=<Invalid>)")

        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, x):
                return x

        c = Func.apply(a)
        self.assertEqual(repr(c), "tensor([2.], grad_fn=<FuncBackward>)")

    def test_autograd_inplace_view_of_view(self):
        x = torch.zeros(2)
        with torch.no_grad():
            y = x.view(2)
        y.requires_grad_(True)
        z = y.view(2)
        with self.assertRaisesRegex(
            RuntimeError, "a view of a view .* is being .* inside the no_grad block"
        ):
            z /= 2

        x = torch.zeros(2)
        with torch.inference_mode():
            y = x.view(2)
        y.requires_grad_(True)
        z = y.view(2)
        with self.assertRaisesRegex(
            RuntimeError, "a view of a view .* is being .* inside the inference_mode"
        ):
            z /= 2

    # TODO This is not the correct behavior -
    # See https://github.com/pytorch/pytorch/issues/49825#issuecomment-794466627
    def test_autograd_inplace_views_cross_dtype(self):
        # This test is here to make sure that any change to this behavior is detected
        # and not silent. The TODOs below mark the places with unexpected behavior.
        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
        a = a_orig.clone()
        b = torch.view_as_real(a)
        b = b.transpose(0, 1)
        b += 1
        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
        non_inplace_grad = a_orig.grad

        a_orig = torch.rand(3, 3, requires_grad=True, dtype=torch.complex64)
        a = a_orig.clone()
        b = torch.view_as_real(a)
        b.transpose_(0, 1)
        b += 1
        b.backward(torch.arange(0, 18, dtype=torch.float).view(3, 3, 2))
        inplace_grad = a_orig.grad

        # TODO: this is a bug!
        # once this is fixed, it should have the transpose removed:
        # self.assertEqual(non_inplace_grad, inplace_grad)
        self.assertEqual(non_inplace_grad.T, inplace_grad)

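    # Minimal sketch (illustrative only; helper name is arbitrary): a real view
    # obtained with torch.view_as_real keeps autograd connectivity, so in-place
    # updates on the real view flow back to the complex base. The test above
    # guards the current, known-buggy grad layout after an in-place transpose_.
    def _sketch_view_as_real_inplace(self):
        a = torch.zeros(2, 2, dtype=torch.complex64, requires_grad=True)
        base = a.clone()
        real_view = torch.view_as_real(base)  # shape (2, 2, 2), real dtype
        real_view += 1  # in-place update is recorded through the view
        base.sum().abs().backward()
        return a.grad  # same shape and dtype as `a`
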
    def test_autograd_multiple_views_python(self):
        # This is not necessarily the absolute correct behavior, but this is the current
        # one. This test is here to make sure that any change to this behavior is detected
        # and not silent. The TODOs below mark the places with unexpected behavior.
        # Note that any change in these tests will be BC-breaking and should be done carefully.

        # This checks that multiple views in the forward are properly traced and how they
        # behave with respect to inplace operations.

        # This indicator is used to track how many times the backward function was called
        bw_called = [0]

        class ComplexView(Function):
            @staticmethod
            def forward(ctx, a, idx):
                res = a.narrow(0, idx, 1)
                res = a.select(0, idx)
                ctx.save_for_backward(a)
                ctx.idx = idx
                return res

            @staticmethod
            def backward(ctx, grad):
                bw_called[0] += 1
                (a,) = ctx.saved_tensors
                res = torch.zeros_like(a)
                res.select(0, ctx.idx).copy_(grad)
                return res, None

        a = torch.ones(2, requires_grad=True)
        idx = 1

        bw_called[0] = 0
        out = ComplexView.apply(a.clone(), idx)
        out.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        out = ComplexView.apply(a.clone(), idx)
        with self.assertRaisesRegex(
            RuntimeError,
            "Output 0 of ComplexViewBackward is a view and is being modified inplace",
        ):
            out += 1

    def test_autograd_python_custom_function_inplace(self):
        # This is not necessarily the absolute correct behavior, but this is the current
        # one. This test is here to make sure that any change to this behavior is detected
        # and not silent. The TODOs below mark the places with unexpected behavior.
        # Note that any change in these tests will be BC-breaking and should be done carefully.

        # This test checks custom autograd.Functions that perform inplace operations

        bw_called = [0]

        # I) Single output
        class MyAdder(Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a

            @staticmethod
            def backward(ctx, grad):
                bw_called[0] += 1
                return grad, grad

        a = torch.ones(2, requires_grad=True)
        b = torch.ones(2, requires_grad=True)

        # No extra inplace
        c = MyAdder.apply(a.clone(), b)
        c.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # With extra inplace on the output
        bw_called[0] = 0
        c = MyAdder.apply(a.clone(), b)
        c += 2
        c.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # The input is a view
        bw_called[0] = 0
        c = MyAdder.apply(a.clone().view_as(a), b)
        c.sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # Should not give non-inputs to mark_dirty
        class MyAdderBad(Function):
            @staticmethod
            def forward(ctx, a, b):
                c = 3 * a
                c.add_(b)
                ctx.mark_dirty(c)
                return c

            @staticmethod
            def backward(ctx, grad):
                bw_called[0] += 1
                grad = 3 * grad
                return grad, grad

        a = torch.ones(2, requires_grad=True)
        b = torch.ones(2, requires_grad=True)

        with warnings.catch_warnings(record=True) as w:
            MyAdderBad.apply(a.clone(), b)
        self.assertEqual(len(w), 1)

        # II) Multiple outputs
        class MyBadAdder(Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a, a + b

            @staticmethod
            def backward(ctx, ga, gab):
                bw_called[0] += 1
                return ga + gab, ga + gab

        # No extra inplace
        bw_called[0] = 0
        c, d = MyBadAdder.apply(a.clone(), b)
        (c * d).sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # With extra inplace on the output
        bw_called[0] = 0
        c, d = MyBadAdder.apply(a.clone(), b)
        c += 2
        (c * d).sum().backward()
        self.assertTrue(bw_called[0] == 1)

        # The input is a view
        inplace_on_view_err = (
            "your Function modifies inplace an input that is a view of another Tensor"
        )
        with self.assertRaisesRegex(RuntimeError, inplace_on_view_err):
            c, d = MyBadAdder.apply(a.clone().view_as(a), b)

        # III) Inplace + other op
        class MyOutPlaceAdder(Function):
            @staticmethod
            def forward(ctx, a, b):
                a.add_(b)
                ctx.mark_dirty(a)
                return a.clone(), a + b

            @staticmethod
            def backward(ctx, ga, gab):
                bw_called[0] += 1
                return ga + gab, ga + 2 * gab

        # We don't reuse the input
        def fn(a, b):
            orig_a = a.clone().view_as(a)
            c, d = MyOutPlaceAdder.apply(orig_a, b)
            return (c * d).sum()

        bad_mark_dirty_err = "Some elements marked as dirty during the forward method were not returned as output."
        with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
            fn(a, b)

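    # Minimal sketch (illustrative only; the helper name is arbitrary) of the
    # ctx.mark_dirty() contract exercised above: only inputs that were modified
    # in-place should be marked dirty, and every marked tensor must also be
    # returned from forward.
    def _sketch_mark_dirty_contract(self):
        class AddOneInplace(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                x.add_(1)          # modify the input in-place
                ctx.mark_dirty(x)  # tell autograd the input was changed
                return x           # the dirtied tensor must be returned

            @staticmethod
            def backward(ctx, grad):
                return grad

        a = torch.zeros(2, requires_grad=True)
        out = AddOneInplace.apply(a.clone())  # never dirty a leaf directly
        out.sum().backward()
        return a.grad
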
    def test_custom_function_mark_dirty_not_differentiable(self):
        def get_custom_fn(jvp_err):
            class InplaceMul(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x):
                    result = x.mul_(2)
                    ctx.mark_dirty(result)
                    return result

                @staticmethod
                def backward(ctx, grad_output):
                    pass

                @staticmethod
                def jvp(ctx, x_t):
                    if jvp_err:
                        return x_t
                    else:
                        return x_t.mul_(2)

            return InplaceMul

        for requires_grad, jvp_err in product([True, False], repeat=2):
            InplaceMul = get_custom_fn(jvp_err)
            # Make sure that tensor is always returned as-is if marked dirty
            z = torch.tensor(1.0, requires_grad=requires_grad)
            x = z.clone()
            y = InplaceMul.apply(x)
            self.assertTrue(x is y)
            self.assertEqual(x, z * 2)

            # jvp must properly modify the input grad if mark_dirty is set
            with fwAD.dual_level():
                x_tangent = torch.ones_like(x)
                x_dual = fwAD.make_dual(x, x_tangent)

                if jvp_err:
                    bad_mark_dirty_err = (
                        "jvp function must modify the corresponding gradient inplace"
                    )
                    with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
                        InplaceMul.apply(x_dual)
                else:
                    out_dual = InplaceMul.apply(x_dual)
                    _, out_tangent = fwAD.unpack_dual(out_dual)
                    self.assertTrue(out_dual is x_dual)
                    self.assertTrue(out_tangent is x_tangent)

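    # Illustrative sketch (not part of the suite; names are arbitrary): when a
    # custom Function marks its input dirty, its jvp must also update the
    # corresponding tangent in-place and return it, mirroring the test above.
    def _sketch_mark_dirty_jvp(self):
        class DoubleInplace(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                x.mul_(2)
                ctx.mark_dirty(x)
                return x

            @staticmethod
            def backward(ctx, grad):
                return 2 * grad

            @staticmethod
            def jvp(ctx, x_t):
                return x_t.mul_(2)  # in-place on the tangent, then returned

        with fwAD.dual_level():
            x = torch.ones(3).clone()
            dual = fwAD.make_dual(x, torch.ones(3))
            out = DoubleInplace.apply(dual)
            _, tangent = fwAD.unpack_dual(out)
        return tangent
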
    def test_named_tensor_for_complex_views(self):
        names = ["batch", "height", "width", "complex"]
        z = torch.ones((2, 1, 2, 2), requires_grad=True)
        z_named = z.refine_names(*names)
        z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(
            *names[:-1]
        )
        z_complex.sum().abs().backward()
        expected = torch.ones_like(z_complex).rename(None)
        abs_1_1j = abs(1 + 1j)
        expected.fill_(complex(abs_1_1j / 2, abs_1_1j / 2))
        self.assertEqual(z.grad, torch.view_as_real(expected))

    def test_custom_function_return_view_in_nograd(self):
        class Alias(Function):
            @staticmethod
            def forward(ctx, x):
                return x[:]

            @staticmethod
            def backward(ctx, gx):
                return gx

        inp = torch.rand(2, requires_grad=True)

        with torch.no_grad():
            output = Alias.apply(inp)

        with torch.no_grad():
            expected_output = inp[:]

        # Calling the custom function should operate as if we called an equivalent op
        self.assertEqual(output.requires_grad, expected_output.requires_grad)

        # Check that in-place modification on view throws
        leaf_grad_err = (
            "A view was created in no_grad mode and is being modified inplace"
        )
        with self.assertRaisesRegex(RuntimeError, leaf_grad_err):
            output.zero_()

    def test_custom_function_preserve_torch_function_when_return_as_is(self):
        class Custom(torch.Tensor):
            def __init__(self, data):
                super().__init__()
                self._data = data

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                kwargs = {} if kwargs is None else kwargs
                args = tuple(a._data if isinstance(a, cls) else a for a in args)
                out = func(*args, **kwargs)
                if isinstance(out, torch.Tensor):
                    out = cls(out)
                return out

        class Fn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            def backward(ctx):
                pass

        x = Custom(torch.randn(2, 3))
        y = Fn.apply(x)
        self.assertTrue(isinstance(y, Custom))

    def test_grad_mode_restored_reentrant(self):
        class MyFunction(Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.clone()

            @staticmethod
            def backward(ctx, go):
                original = torch._C.is_grad_enabled()
                with torch.enable_grad():
                    self.assertTrue(torch._C.is_grad_enabled())
                    foo = torch.rand(go.size(), requires_grad=True)
                    (grad,) = torch.autograd.grad(foo**3, foo, grad_outputs=go)
                    self.assertTrue(torch._C.is_grad_enabled())
                self.assertTrue(torch._C.is_grad_enabled() == original)
                return grad

        inp = torch.rand(3, requires_grad=True)

        # Case where original==False
        MyFunction.apply(inp).sum().backward()
        # Case where original==True
        MyFunction.apply(inp).sum().backward(create_graph=True)

    def test_power_function(self):
        a = torch.tensor([0.0, 0.0, 0.0])
        b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
        c = torch.sum(a**b)
        c.backward()
        self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0]))

        s = 0
        b = torch.tensor([-1.0, 0.0, 1.0], requires_grad=True)
        c = torch.sum(s**b)
        c.backward()
        self.assertEqual(b.grad, torch.tensor([-inf, 0.0, 0.0]))

    def test_custom_function_error(self):
        class BadFw(Function):
            @staticmethod
            def backward(ctx, foo):
                return foo

        class BadBw(Function):
            @staticmethod
            def forward(ctx, foo):
                return foo.clone()

        class BadBw2(Function):
            @staticmethod
            def forward(ctx, foo):
                return foo.clone()

            @staticmethod
            def backward(ctx, foo):
                return foo

            @staticmethod
            def vjp(ctx, foo):
                return foo

        class BadJvp(Function):
            @staticmethod
            def forward(ctx, foo):
                return foo.clone()

        inp = torch.rand(1, requires_grad=True)
        with self.assertRaisesRegex(NotImplementedError, "must implement the forward"):
            BadFw.apply(inp)

        with self.assertRaisesRegex(RuntimeError, "must implement either the backward"):
            BadBw.apply(inp).sum().backward()

        with self.assertRaisesRegex(
            RuntimeError, "Implementing both 'backward' and 'vjp'"
        ):
            BadBw2.apply(inp).sum().backward()

        with self.assertRaisesRegex(RuntimeError, "must implement the jvp function"):
            with fwAD.dual_level():
                d = fwAD.make_dual(inp, torch.rand_like(inp))
                res = BadJvp.apply(d)

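    # Minimal sketch (illustrative; not exercised by the suite, names are
    # arbitrary): a complete custom Function needs forward() plus exactly one of
    # backward()/vjp() for reverse mode, and jvp() if it will be driven under
    # forward-mode AD, which is what the errors above enforce.
    def _sketch_minimal_custom_function(self):
        class Scale(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return 2 * x

            @staticmethod
            def backward(ctx, grad_out):  # reverse-mode rule
                return 2 * grad_out

            @staticmethod
            def jvp(ctx, x_t):  # forward-mode rule
                return 2 * x_t

        x = torch.rand(3, requires_grad=True)
        Scale.apply(x).sum().backward()
        return x.grad
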
    def test_custom_function_forward_mode_view_checks(self):
        flag_to_error = {
            "ok": None,
            "not_a_view": "jvp is not returning a view",
            "not_a_view_of_inp": "jvp is not returning a view of the given",
            "not_a_view_of_inp_base": "jvp is not returning a view of the same base",
        }

        class ViewFn(Function):
            @staticmethod
            def forward(ctx, foo, flag):
                ctx.flag = flag
                ctx.size = foo.size()
                return foo.narrow(0, 0, 2)

            @staticmethod
            def vjp(ctx, gO):
                gI = gO.new_zeros(ctx.size)
                gI.narrow(0, 0, 2).copy_(gO)
                return gI, None

            @staticmethod
            def jvp(ctx, gI, _):
                res = gI.narrow(0, 0, 2)
                if ctx.flag != "ok":
                    # Break the view in the gradients!
                    res = res.clone()
                if ctx.flag in ["not_a_view_of_inp", "not_a_view_of_inp_base"]:
                    # Result should be a view, just of the wrong thing
                    res = res.view_as(res)
                return res

        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)

        for flag, msg in flag_to_error.items():

            def test_fn(inp):
                if flag == "not_a_view_of_inp_base":
                    inp = inp.view_as(inp)
                return ViewFn.apply(inp, flag)

            if msg is None:
                gradcheck(test_fn, inp, check_forward_ad=True)
            else:
                with self.assertRaisesRegex(RuntimeError, msg):
                    gradcheck(test_fn, inp, check_forward_ad=True)

    def test_custom_function_forward_mode_inplace_checks(self):
        class InplaceFn(Function):
            @staticmethod
            def forward(ctx, foo, flag):
                ctx.mark_dirty(foo)
                ctx.flag = flag
                foo.mul_(2)
                return foo

            @staticmethod
            def vjp(ctx, gO):
                return 2 * gO, None

            @staticmethod
            def jvp(ctx, gI, _):
                if ctx.flag:
                    # Don't do the change inplace
                    return 2 * gI
                else:
                    gI.mul_(2)
                    return gI

        inp = torch.rand(4, 4, dtype=torch.double, requires_grad=True)

        def test_fn(inp, flag):
            inp = inp.clone()
            return InplaceFn.apply(inp, flag)

        gradcheck(test_fn, (inp, False), check_forward_ad=True)

        with self.assertRaisesRegex(
            RuntimeError,
            "inplace custom Function is not modifying the forward mode gradients inplace",
        ):
            gradcheck(test_fn, (inp, True), check_forward_ad=True)

    def test_custom_function_forward_mode_wrong_formula(self):
        class UserFn(Function):
            @staticmethod
            def forward(ctx, foo, should_fail):
                ctx.should_fail = should_fail
                return foo * 2

            @staticmethod
            def vjp(ctx, gO):
                return 2 * gO, None

            @staticmethod
            def jvp(ctx, gI, _):
                if ctx.should_fail:
                    # Wrong gradient formula
                    return 3 * gI
                else:
                    return 2 * gI

        inp = torch.rand(10, dtype=torch.double, requires_grad=True)
        gradcheck(UserFn.apply, (inp, False), check_forward_ad=True)

        with self.assertRaisesRegex(
            RuntimeError, "Jacobian computed with forward mode mismatch for output 0"
        ):
            gradcheck(UserFn.apply, (inp, True), check_forward_ad=True)

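    # Short sketch (illustrative; names are arbitrary) of the forward-mode AD
    # API that the surrounding tests drive through gradcheck: wrap a primal and
    # a tangent into a dual tensor, run the op, and read the output tangent.
    def _sketch_forward_mode_ad_usage(self):
        primal = torch.rand(3)
        tangent = torch.ones(3)
        with fwAD.dual_level():
            dual = fwAD.make_dual(primal, tangent)
            out = dual * 2
            out_primal, out_tangent = fwAD.unpack_dual(out)
        # out_tangent is the JVP of x -> 2 * x applied to `tangent`, i.e. 2 * tangent
        return out_primal, out_tangent
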
    def test_custom_function_forward_mode_non_tensor_before_tensor_args(self):
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, nt, x, nt2, y):
                return x * 2 + y * 3

            @staticmethod
            def jvp(ctx, nt, x_t, nt2, y_t):
                self.assertIsNone(nt)
                self.assertIsNone(nt2)
                return x_t * 2 + y_t * 3

        x = torch.tensor(1.0, dtype=torch.double)
        t = torch.tensor(1.0, dtype=torch.double)
        y = torch.tensor(1.0, dtype=torch.double)

        with fwAD.dual_level():
            dual_x = fwAD.make_dual(x, t)
            MyFn.apply(1, dual_x, 1, y)

        gradcheck(
            MyFn.apply,
            (1, x.requires_grad_(True), 1, y.requires_grad_(True)),
            check_forward_ad=True,
            check_backward_ad=False,
            check_batched_grad=False,
        )

    def test_custom_function_forward_mode_forward_is_no_op(self):
        error_regex = (
            "A custom Function's forward is returning a view \\(or an input as-is\\)"
        )

        return_lambdas = {
            # If we return an input as-is in forward, that is treated
            # as if self.view_as(self) is performed. If jvp returns x.view_as(x),
            # this is OK.
            "view_as": lambda x: x.view_as(x),
            # Expect this to raise an error
            "self": lambda x: x,
            # Expect this to raise the same error
            "mul_by_2": lambda x: x * 2,
        }

        for k, fn in return_lambdas.items():

            class MyFn(torch.autograd.Function):
                @staticmethod
                def forward(ctx, x, y):
                    return x + y, x

                @staticmethod
                def vjp(ctx, gO1, gO2):
                    return gO1 + gO2, gO1

                @staticmethod
                def jvp(ctx, x_t, y_t):
                    return x_t + y_t, fn(x_t)

            a = torch.tensor(1.0, dtype=torch.double, requires_grad=True)
            t = torch.tensor(1.0, dtype=torch.double)
            b = torch.tensor(1.0, dtype=torch.double, requires_grad=True)

            c = torch.tensor(1.0, dtype=torch.double)
            t2 = torch.tensor(1.0, dtype=torch.double)
            d = torch.tensor(1.0, dtype=torch.double)

            with fwAD.dual_level():
                a_dual = fwAD.make_dual(a, t)
                c_dual = fwAD.make_dual(c, t2)

                if k == "view_as":
                    _, out2 = MyFn.apply(a_dual, b)
                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t)

                    _, out2 = MyFn.apply(c_dual, d)
                    self.assertTrue(fwAD.unpack_dual(out2).tangent._base is t2)
                else:
                    with self.assertRaisesRegex(RuntimeError, error_regex):
                        MyFn.apply(a_dual, b)

                    with self.assertRaisesRegex(RuntimeError, error_regex):
                        MyFn.apply(c_dual, d)

            if k == "view_as":
                gradcheck(MyFn.apply, (a, c), check_forward_ad=True)
            else:
                with self.assertRaisesRegex(RuntimeError, error_regex):
                    gradcheck(MyFn.apply, (a, c), check_forward_ad=True)

    def test_custom_function_save_for_forward(self):
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
                ctx.save_for_backward(x, y)
                ctx.save_for_forward(x, y)
                ctx.z = z
                ctx.prod = x * y
                return z * ctx.prod

            @staticmethod
            def jvp(ctx, x_t, y_t, _):
                x_p, y_p = ctx.saved_tensors
                z = ctx.z
                return z * (y_p * x_t + x_p * y_t)

            @staticmethod
            def vjp(ctx, grad_out):
                x, y = ctx.saved_tensors
                z = ctx.z
                return z * grad_out * y, z * grad_out * x, None

        a = torch.tensor(1.0, requires_grad=True, dtype=torch.double)
        t = torch.tensor(1.0, dtype=torch.double)
        b = torch.tensor(2.0, requires_grad=True, dtype=torch.double)
        c = 4

        with fwAD.dual_level():
            a_dual = fwAD.make_dual(a, t)
            out = Func.apply(a_dual, b, c)
            out.backward()

        gradcheck(Func.apply, (a, b, c), check_forward_ad=True)

        # When saved for backward, but not saved for forward
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x: torch.Tensor):
                ctx.save_for_backward(x)
                return x.clone()

            @staticmethod
            def jvp(ctx, x_t):
                self.assertEqual(len(ctx.saved_tensors), 0)
                return x_t

            @staticmethod
            def vjp(ctx, grad_out):
                (x,) = ctx.saved_tensors
                self.assertEqual(len(ctx.saved_tensors), 1)
                return grad_out

        with fwAD.dual_level():
            a_dual = fwAD.make_dual(a, t)
            out = Func.apply(a_dual)
            out.backward()

        gradcheck(Func.apply, (a,), check_forward_ad=True)

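    # Illustrative sketch (names are arbitrary): tensors needed by jvp() must be
    # stashed with ctx.save_for_forward(); ctx.save_for_backward() only feeds
    # ctx.saved_tensors for the reverse-mode vjp/backward, as the test above checks.
    def _sketch_save_for_forward(self):
        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)  # visible to vjp
                ctx.save_for_forward(x)   # visible to jvp
                return x * x

            @staticmethod
            def vjp(ctx, g):
                (x,) = ctx.saved_tensors
                return 2 * x * g

            @staticmethod
            def jvp(ctx, x_t):
                (x,) = ctx.saved_tensors
                return 2 * x * x_t

        with fwAD.dual_level():
            dual = fwAD.make_dual(torch.rand(3), torch.ones(3))
            out = Square.apply(dual)
            _, tangent = fwAD.unpack_dual(out)
        return tangent
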
    @skipIfTorchDynamo("compile tested in test/dynamo/test_autograd_function.py")
    def test_custom_function_forward_mode_non_differentiable(self):
        # returns differentiable type, marked non-differentiable
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                out = y.clone()
                ctx.mark_non_differentiable(out)
                return x.clone(), out

            @staticmethod
            def jvp(ctx, x_tangent, y_tangent):
                return x_tangent, None

        x = torch.tensor(2.0)
        x_tangent = torch.tensor(1.0)
        y = torch.tensor(3.0)

        with fwAD.dual_level():
            x_dual = fwAD.make_dual(x, x_tangent)
            _, out2_dual = Func.apply(x_dual, y)
            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)

        y = torch.tensor(3)

        # returns non-differentiable type, NOT marked non-differentiable
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return x.clone(), y.clone()

            @staticmethod
            def jvp(ctx, x_tangent, y_tangent):
                self.assertIsNone(y_tangent)
                return x_tangent, None

        with fwAD.dual_level():
            x_dual = fwAD.make_dual(x, x_tangent)
            _, out2_dual = Func.apply(x_dual, y)
            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, None)

        class FuncWrong(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                out = y.clone()
                ctx.mark_non_differentiable(out)
                return x.clone(), out

            @staticmethod
            def jvp(ctx, x_tangent, y_tangent):
                return x_tangent, x_tangent.clone()

        with fwAD.dual_level():
            x_dual = fwAD.make_dual(x, x_tangent)
            with self.assertRaisesRegex(
                RuntimeError, "You should return None at that position instead"
            ):
                FuncWrong.apply(x_dual, y)

        # returns non-tensor
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone(), object(), x.clone()

            @staticmethod
            def jvp(ctx, x_tangent):
                return x_tangent, None, x_tangent

        with fwAD.dual_level():
            x_dual = fwAD.make_dual(x, x_tangent)
            out_dual, _, out2_dual = Func.apply(x_dual)
            self.assertEqual(fwAD.unpack_dual(out_dual).tangent, x_tangent)
            self.assertEqual(fwAD.unpack_dual(out2_dual).tangent, x_tangent)

    def test_custom_function_local_inplace(self):
        class MyFn(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp, inplace):
                view = inp.clone()[:3]
                if inplace:
                    view += 2
                return view

            @staticmethod
            def backward(ctx, grad):
                return grad, None

        base = torch.rand(10, requires_grad=True)

        foo = MyFn.apply(base, False)
        self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward")

        foo = MyFn.apply(base, True)
        self.assertEqual(foo.grad_fn.__class__.__name__, "MyFnBackward")

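    # Illustrative sketch (names are arbitrary): in-place ops that only touch
    # intermediate tensors created inside forward(), like the local `view` above,
    # do not need ctx.mark_dirty(); that is only required for inputs modified in-place.
    def _sketch_local_inplace_no_mark_dirty(self):
        class AddTwoLocal(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                tmp = inp.clone()  # local intermediate, not an input
                tmp += 2           # in-place on the intermediate is fine
                return tmp

            @staticmethod
            def backward(ctx, grad):
                return grad

        x = torch.rand(4, requires_grad=True)
        out = AddTwoLocal.apply(x)
        out.sum().backward()
        return x.grad
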
        out = torch.bucketize(vals, bins)
        self.assertFalse(out.dtype.is_floating_point)
        self.assertFalse(out.requires_grad)

        val = torch.empty(5).requires_grad_()
        out = val.count_nonzero()
        self.assertFalse(out.requires_grad)

        def assert_only_first_requires_grad(res):
            if not isinstance(res, tuple):
                res = (res,)
            self.assertTrue(res[0].requires_grad)
            for out in res[1:]:
                if out is not None:
                    self.assertFalse(out.requires_grad)

        for sort in [True, False]:
            for return_inverse in [True, False]:
                for return_counts in [True, False]:
                    res = torch.unique(
                        inp,
                        sorted=sort,
                        return_inverse=return_inverse,
                        return_counts=return_counts,
                    )
                    assert_only_first_requires_grad(res)

                    res = torch.unique(
                        inp,
                        sorted=sort,
                        return_inverse=return_inverse,
                        return_counts=return_counts,
                        dim=0,
                    )
                    assert_only_first_requires_grad(res)

                    res = torch.unique_consecutive(
                        inp, return_inverse=return_inverse, return_counts=return_counts
                    )
                    assert_only_first_requires_grad(res)

                    res = torch.unique_consecutive(
                        inp,
                        return_inverse=return_inverse,
                        return_counts=return_counts,
                        dim=0,
                    )
                    assert_only_first_requires_grad(res)

                    # Here we test the internal functions to make sure all of them are
                    # covered on top of the public API
                    res = torch._unique(inp, sorted=sort, return_inverse=return_inverse)
                    assert_only_first_requires_grad(res)

                    # This looks public but is actually manually deleted from the
                    # torch namespace in torch/functional.py
                    res = torch._VF.unique_dim(
                        inp,
                        dim=0,
                        sorted=sort,
                        return_inverse=return_inverse,
                        return_counts=return_counts,
                    )
                    assert_only_first_requires_grad(res)

                    # We don't test `unique_dim_consecutive` here.
                    # It looks public but the python binding is actually manually disabled in
                    # tools/autograd/gen_python_functions.py

                    res = torch._unique2(
                        inp,
                        sorted=sort,
                        return_inverse=return_inverse,
                        return_counts=return_counts,
                    )
                    assert_only_first_requires_grad(res)

    def test_custom_function_cycle(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x, metadata):
                x = x.clone()
                ctx.meta = metadata
                ctx.save_for_backward(x)
                return x

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                self.assertEqual(x, 3.14)
                self.assertEqual(ctx.meta["foo"], 3.14)
                return gO * x, None

        def get_refs(with_backward):
            a = torch.tensor(3.14, requires_grad=True)

            metadata = {}
            out = MyFn.apply(a, metadata)

            metadata["foo"] = out

            if with_backward:
                out.sum().backward()
                self.assertEqual(a.grad, a)

            return torch._C._WeakTensorRef(out)

        with disable_gc():
            ref = get_refs(False)
            self.assertFalse(ref.expired())
            gc.collect()
            self.assertTrue(ref.expired())

        # The backward clears the saved_variables but not the __dict__
        with disable_gc():
            ref = get_refs(True)
            self.assertFalse(ref.expired())
            gc.collect()
            self.assertTrue(ref.expired())

    def test_create_graph_and_full_backward_hook_cycle(self):
        # If BackwardHook saves grad_output, it can create a cycle when we perform backward
        # with create_graph=True
        #
        # grad_output -> grad_output.grad_fn -> graph -> hook -> grad_output
        #
        class TestCls:
            # Dummy class for the purpose of creating a weakref
            pass

        def get_ref(input_requires_grad, nb_hooks):
            t = torch.randn(10, requires_grad=input_requires_grad)
            a = torch.tensor(1.0, requires_grad=True)

            class Test(nn.Module):
                def forward(self, x):
                    return x**2 * a**2

            mod = Test()

            for _ in range(nb_hooks):
                mod.register_full_backward_hook(lambda a, b, c: None)

            tmp = mod(t)

            # Save dummy object to graph and get a weak ref to it
            test = TestCls()
            ref = weakref.ref(test)
            tmp.grad_fn.metadata["a"] = test

            with set_warn_always_context(True):
                with warnings.catch_warnings(record=True) as w:
                    tmp.exp().sum().backward(create_graph=True)
                    self.assertTrue(len(w) == 1)
                    self.assertTrue(
                        "Using backward() with create_graph=True" in str(w[0].message)
                    )

            # Remove the backward + create_graph=True cycle
            a.grad = None
            t.grad = None

            return ref

        for nb_hooks in (1, 2, 3):
            for input_requires_grad in (True, False):
                ref_ = get_ref(
                    input_requires_grad=input_requires_grad,
                    nb_hooks=nb_hooks,
                )
                gc.collect()
                self.assertIsNone(ref_())

    @parametrize("use_custom_function", [True, False])
    @parametrize("use_tensor_hook", [True, False])
    def test_hook_closure_cycle(self, use_custom_function, use_tensor_hook):
        # This creates a cycle between the hook and grad_fn_b
        # hook -> closure -> grad_fn_b (python) -> grad_fn (cpp) -> hook (cpp)
        # -> dict -> hook
        #
        # This test is testing that the grad_fn_b (python) only traverses the
        # dict if it is the only one holding a reference to the grad_fn_b (cpp)
        # shared_ptr
        #
        # See: https://github.com/pytorch/pytorch/issues/102174
        class Function(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, grad):
                return grad

        class Test:
            pass

        count = [0]

        def scope():
            a = torch.tensor(1.0, requires_grad=True)
            if use_custom_function:
                b = Function.apply(a)
            else:
                b = a.clone()
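            # Whichever branch ran, b.grad_fn is the Python-visible autograd Node;
            # capturing it (together with the sentinel obj) in the hook closure below
            # is what creates the reference cycle described at the top of this test.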
            grad_fn_b = b.grad_fn
            obj = Test()

            def hook(*args):
                # Make sure this hook's closure holds onto grad_fn_b
                # This forms a cycle between the hook and grad_fn_b
                # We also hold onto a sentinel object 'obj' to track
                # whether this cycle is still alive. See 'ref' below.
                grad_fn_b
                obj
                count[0] += 1

            if use_tensor_hook:
                b.register_hook(hook)
            else:
                b.grad_fn.register_hook(hook)
            c = b.clone()
            ref = weakref.ref(obj)
            return c, ref

        with disable_gc():
            out, ref = scope()
            out.backward(retain_graph=True)

            gc.collect()

            # Make sure gc does not clear the cycle noted above.
            # e.g. the hook is alive and gets fired even after gc runs
            out.backward(retain_graph=True)
            self.assertEqual(count[0], 2)

            # ref is still alive because the use_count of the cpp grad_fn
            # shared_ptr > 1 since (1) the python grad_fn is alive, and (2) the
            # rest of the graph holds onto the shared_ptr
            self.assertIsNotNone(ref())

            # Then delete the rest of the graph and check that ref is dead
            del out
            gc.collect()
            self.assertIsNone(ref())

    def test_full_backward_hook_double_backward(self):
        x = torch.rand(1, requires_grad=True)
        y = torch.rand_like(x)

        func = torch.nn.MSELoss()
        counter = [0]

        def hook(module, grad_input, grad_output):
            counter[0] += 1

        func.register_full_backward_hook(hook)

        f = func(x, y)

        (gradx_f,) = torch.autograd.grad(f, x, create_graph=True)
        self.assertEqual(counter[0], 1)
        _ = torch.autograd.grad(gradx_f, x)
        # We should not error, and counter should not be incremented
        self.assertEqual(counter[0], 1)

    def test_input_buffer_accum(self):
        leaf = torch.rand(2, 2, requires_grad=True)

        # An op that returns sparse gradients
        ind = torch.tensor([[0, 0]], dtype=torch.long)
        out2 = leaf.gather(0, ind, sparse_grad=True)

        # An op that returns the gradients as-is
        out1 = leaf.clone()

        grad_out1_original = torch.rand_like(out1)
        grad_out1 = grad_out1_original.clone()
        grad_out2 = torch.rand_like(out2)

        torch.autograd.backward((out1, out2), (grad_out1, grad_out2))

        # Given gradients should not be modified inplace
        self.assertEqual(grad_out1, grad_out1_original)

    def test_no_unnecessary_unwrapping(self):
        a = torch.randn(5, requires_grad=True)
        a_orig = a.detach().clone()
        b = a * a
        c = a * b
        d = torch.exp(a)

        # a is leaf
        self.assertIs(b.grad_fn._saved_self, a)
        self.assertIs(b.grad_fn._saved_other, a)
        self.assertIs(c.grad_fn._saved_self, a)

        # b is not an output
        self.assertIs(c.grad_fn._saved_other, b)

        # d is an output
        self.assertEqual(d.grad_fn._saved_result, d)
        self.assertIsNot(d.grad_fn._saved_result, d)

        c.sum().backward()

        with self.assertRaisesRegex(RuntimeError, "after they have already been freed"):
            c.grad_fn._saved_self

        # a is left untouched
        self.assertEqual(a, a_orig)

    def test_saved_variable_version_counter(self):
        a = torch.rand(2, requires_grad=True)

        b = torch.exp(a)

        b_unpacked = b.grad_fn._saved_result
        self.assertEqual(b, b_unpacked)
        self.assertEqual(b._version, b_unpacked._version)

        with torch.no_grad():
            b += 1

        self.assertEqual(b, b_unpacked)
        self.assertEqual(b._version, b_unpacked._version)

    def test_saved_variable_packing_unpacking_saved_original_with_hooks(self):
        # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
        # attribute of `SavedVariable`.

        def test(get_input, is_leaf):
            a = get_input()
            grad_fn = a.grad_fn
            y = a * a
            y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x / 2)
            self.assertEqual(a, y.grad_fn._saved_self)
            if not is_leaf:
                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
                y.sum().backward()
            else:
                y.sum().backward()
                self.assertEqual(2 * a, a.grad)

            a = get_input()
            grad_fn = a.grad_fn
            y = a * a
            y.grad_fn._raw_saved_self.register_hooks(lambda x: 2 * x, lambda x: x)
            self.assertEqual(2 * a, y.grad_fn._saved_self)
            if not is_leaf:
                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
                y.sum().backward()
            else:
                y.sum().backward()
                self.assertEqual(3 * a, a.grad)

            # double backward
            a = get_input()
            grad_fn = a.grad_fn
            y = a**3
            y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: x)
            s = torch.sum(y)
            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
            if not is_leaf:
                self.assertIs(grad_fn, y.grad_fn._saved_self.grad_fn)
                g.sum().backward()
            else:
                g.sum().backward()
                self.assertEqual(6 * a, a.grad)

            a = get_input()
            y = a * a
            y.grad_fn._raw_saved_self.register_hooks(lambda x: x, lambda x: 1)
            with self.assertRaisesRegex(
                TypeError, "Output of saved tensor unpack_hook expected to be a Tensor"
            ):
                print(y.grad_fn._saved_self)

            a = get_input()
            y = a * a
            with self.assertRaisesRegex(
                TypeError, "missing 1 required positional argument"
            ):
                y.grad_fn._raw_saved_self.register_hooks(lambda x, b: x, lambda x: x)

            a = get_input()
            y = a * a
            with self.assertRaisesRegex(
                TypeError, "missing 1 required positional argument"
            ):
                y.grad_fn._raw_saved_self.register_hooks(
                    lambda x, b: (x, b), lambda x: x
                )

            def inplace_double(x):
                x *= 2
                return x

            a = get_input()
            t = a * a

            with self.assertRaisesRegex(
                RuntimeError,
                "A saved tensor pack hook is modifying its input in place.",
            ):
                t.grad_fn._raw_saved_self.register_hooks(
                    inplace_double, lambda x: x / 2
                )

        # leaf
        test(lambda: torch.randn(5, requires_grad=True), True)

        # not leaf, not output
        test(lambda: (1 + torch.randn(5, requires_grad=True)), False)

    def test_saved_variable_saved_original_inplace_detach(self):
        # Detaching a tensor that is saved input raises
        a = torch.tensor(1.0, requires_grad=True).clone()
        b = a.sin()
        a.detach_()
        with self.assertRaisesRegex(
            RuntimeError, "Trying to use a saved tensor that has been detached"
        ):
            b.backward()

        # Detaching a tensor that is saved as output is OK
        a = torch.tensor(1.0, requires_grad=True).clone()
        b = a.exp()
        a.detach_()
        b.backward()

    def test_saved_variable_packing_unpacking_did_not_save_original_with_hooks(self):
        # Tests that packing/unpacking a SavedVariable works correctly with user-defined hooks
        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
        # attribute of `SavedVariable`.

        a = torch.randn(5, requires_grad=True)
        y = torch.exp(a)
        y.grad_fn._raw_saved_result.register_hooks(lambda x: x, lambda x: x)
        self.assertEqual(y, y.grad_fn._saved_result)
        self.assertIs(y.grad_fn, y.grad_fn._saved_result.grad_fn)
        y.sum().backward()
        self.assertEqual(a.grad, y)

    def test_saved_variable_packing_unpacking_saved_original_with_default_hooks(self):
        # Tests that default hooks are properly registered, used and reset
        # The saved_original / did_not_save_original distinction corresponds to the `save_original`
        # attribute of `SavedVariable`.
        # See also:
        # - test_saved_variable_packing_unpacking_saved_original_with_hooks

        def pack(x):
            warnings.warn("pack")
            return x

        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
            a = torch.ones(5, requires_grad=True)

            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter("always")
                y = a * a
                # should raise two warnings from a being saved twice
                self.assertEqual(len(w), 2)

        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            a = torch.randn(5, requires_grad=True)
            y = a * a
            self.assertEqual(a, y.grad_fn._saved_self)
            self.assertEqual(a, y.grad_fn._saved_other)
            y.sum().backward()
            self.assertEqual(2 * a, a.grad)

        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x / 2):
            a = torch.randn(5, requires_grad=True)
            y = a * a
            self.assertEqual(a, y.grad_fn._saved_self)
            self.assertEqual(a, y.grad_fn._saved_other)
            y.sum().backward()
            self.assertEqual(2 * a, a.grad)

        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
            a = torch.randn(5, requires_grad=True)
            y = a * a
            self.assertEqual(2 * a, y.grad_fn._saved_self)
            self.assertEqual(2 * a, y.grad_fn._saved_other)
            y.sum().backward()
            self.assertEqual(4 * a, a.grad)

        # Exited hooks correctly
        a = torch.randn(5, requires_grad=True)
        y = a * a
        self.assertEqual(a, y.grad_fn._saved_self)
        self.assertEqual(a, y.grad_fn._saved_other)
        y.sum().backward()
        self.assertEqual(2 * a, a.grad)

    def test_saved_variable_packing_unpacking_did_not_save_original_with_default_hooks(
        self,
    ):
        # See also test_saved_variable_packing_unpacking_did_not_save_original_with_hooks

        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            a = torch.randn(5, requires_grad=True)
            y = torch.exp(a)
            self.assertEqual(y, y.grad_fn._saved_result)
            y.sum().backward()
            self.assertEqual(a.grad, y)

    def test_setting_default_saved_variable_hooks_twice_should_not_fail(self):
        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
                pass

    def test_setting_default_saved_variable_hooks_twice_should_use_inner(self):
        with torch.autograd.graph.saved_tensors_hooks(lambda x: 3 * x, lambda x: 3 * x):
            b = torch.randn(5, requires_grad=True)
            with torch.autograd.graph.saved_tensors_hooks(
                lambda x: 5 * x, lambda x: 5 * x
            ):
                a = torch.randn(5, requires_grad=True)
                y = a * a
            z = b * b
        y.sum().backward()
        z.sum().backward()
        self.assertEqual(2 * 5 * 5 * a, a.grad)
        self.assertEqual(2 * 3 * 3 * b, b.grad)

    def test_disabling_saved_tensor_hooks(self):
        with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
            with self.assertRaisesRegex(RuntimeError, "error message"):
                with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
                    pass

        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())

        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            with self.assertRaisesRegex(RuntimeError, "error message"):
                with torch.autograd.graph.disable_saved_tensors_hooks("error message"):
                    pass

        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())

    def test_disabling_saved_tensor_hooks_nested(self):
        with torch.autograd.graph.disable_saved_tensors_hooks("outer"):
            with torch.autograd.graph.disable_saved_tensors_hooks("inner"):
                with self.assertRaisesRegex(RuntimeError, "inner"):
                    with torch.autograd.graph.saved_tensors_hooks(
                        lambda x: x, lambda x: x
                    ):
                        pass

            self.assertFalse(torch._C._autograd._saved_tensors_hooks_is_enabled())

        self.assertTrue(torch._C._autograd._saved_tensors_hooks_is_enabled())

    def test_saved_tensor_hooks_custom_error_propagation(self):
        class CustomError(Exception):
            pass

        class error_on_pack_hook(torch.autograd.graph.saved_tensors_hooks):
            def __init__(self) -> None:
                def pack_hook(x):
                    raise CustomError("pack")

                super().__init__(pack_hook, lambda x: x)

        class error_on_unpack_hook(torch.autograd.graph.saved_tensors_hooks):
            def __init__(self) -> None:
                def unpack_hook(x):
                    raise CustomError("unpack")

                super().__init__(lambda x: x, unpack_hook)

        a = torch.tensor(1.0, requires_grad=True)

        with error_on_pack_hook():
            with self.assertRaisesRegex(CustomError, "pack"):
                out = torch.sin(a)

        with error_on_unpack_hook():
            out = torch.sin(a)
            with self.assertRaisesRegex(CustomError, "unpack"):
                out.backward()

    def test_saved_tensor_hooks_custom_function_intermediates(self):
        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                intermediate = x.exp()
                ctx.save_for_backward(
                    intermediate.clone().detach_().requires_grad_(True)
                )
                return x.exp()

            @staticmethod
            def backward(ctx, grad_out):
                (intermediate,) = ctx.saved_tensors
                return grad_out * intermediate

        a = torch.tensor(1.0, requires_grad=True)

        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            out = Func.apply(a)
            out.backward()

    def test_unpack_hooks_exec_count(self):
        def f(x, y):
            return x * y

        pack_count = 0
        unpack_count = 0

        def pack_hook(x):
            nonlocal pack_count
            pack_count += 1
            return x

        # unpack hook shouldn't run during compilation, while we trace the forward
        def unpack_hook(x):
            nonlocal unpack_count
            unpack_count += 1
            return x

        x = torch.ones(4, requires_grad=True)
        y = torch.ones(4, requires_grad=False)
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            out_test = f(x, y)
            self.assertEqual(pack_count, 1)
            self.assertEqual(unpack_count, 0)
        out_test.sum().backward()
        self.assertEqual(pack_count, 1)
        self.assertEqual(unpack_count, 1)

    def test_saved_tensors_hook_version_counter_not_shared(self):
        class Test(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                return x.sin()

            @staticmethod
            def backward(ctx, grad_output):
                (x,) = ctx.saved_tensors
                before = a._version
                x.add_(1)
                self.assertEqual(a._version, before)
                return grad_output

        a = torch.tensor(1.0, requires_grad=True)
        a_replacement = a.clone()

        def pack_hook(x):
            return a_replacement

        def unpack_hook(x):
            return x

        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            b = Test.apply(a)

        b.backward()

    def test_save_on_cpu_and_checkpoint(self):
        a = torch.randn(2, 2, requires_grad=True)

        b = a.pow(2).pow(2).pow(2).pow(2)
        b.sum().backward()
        b_grad = a.grad.clone()
        a.grad.zero_()

        with torch.autograd.graph.save_on_cpu():
            h = a.pow(2)
            h = checkpoint(lambda x: x.pow(2).pow(2), h, use_reentrant=False)
            c = h.pow(2)
        c.sum().backward()
        c_grad = a.grad.clone()
        a.grad.zero_()

        def f(a):
            h = a.pow(2)
            with torch.autograd.graph.save_on_cpu():
                h = h.pow(2).pow(2)
            return h.pow(2)

        d = checkpoint(f, a, use_reentrant=False)
        d.sum().backward()
        d_grad = a.grad.clone()

        self.assertEqual(b_grad, c_grad)
        self.assertEqual(b_grad, d_grad)

    def test_pack_hook_with_inplace_modification_should_fail(self):
        a = torch.randn(5, requires_grad=True)

        def inc(x):
            x += 1
            return x

        with torch.autograd.graph.saved_tensors_hooks(inc, lambda x: x):
            with self.assertRaisesRegex(
                RuntimeError,
                "A saved tensor pack hook is modifying its input in place.",
            ):
                y = torch.exp(a)

        y = torch.exp(a)
        with self.assertRaisesRegex(
            RuntimeError, "A saved tensor pack hook is modifying its input in place."
        ):
            y.grad_fn._raw_saved_result.register_hooks(inc, lambda x: x)

    def test_saving_variable_to_disk(self):
        with tempfile.TemporaryDirectory() as tmp_dir:

            def pack(x):
                name = os.path.join(tmp_dir, str(uuid.uuid4()))
                torch.save(x, name)
                return name

            def unpack(name):
                return torch.load(name)

            with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
                a = torch.ones(5, requires_grad=True)
                y = a * a
                self.assertEqual(a, y.grad_fn._saved_self)

                y.sum().backward()
                self.assertEqual(2 * a, a.grad)

    def test_default_saved_tensors_hooks_double_backward(self):
        with torch.autograd.graph.saved_tensors_hooks(lambda x: x, lambda x: x):
            a = torch.randn(5, requires_grad=True)
            y = a**3
            s = torch.sum(y)
            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
            g.sum().backward()
            self.assertEqual(6 * a, a.grad)

        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
            a = torch.randn(5, requires_grad=True)
            y = a**3
            s = torch.sum(y)
        (g,) = torch.autograd.grad(s, (a,), create_graph=True)
        g.sum().backward()
        # factor 2 because only a is saved once
        self.assertEqual(6 * 2 * a, a.grad)

        a = torch.randn(5, requires_grad=True)
        y = a**3
        s = torch.sum(y)
        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
            g.sum().backward()
        # factor 4 because pow_backward is grad * (exp * self.pow(exp - 1))
        # so grad is saved and self (i.e. a) is saved
        self.assertEqual(6 * 4 * a, a.grad)

        with torch.autograd.graph.saved_tensors_hooks(lambda x: 2 * x, lambda x: x):
            a = torch.randn(5, requires_grad=True)
            y = a**3
            s = torch.sum(y)
            (g,) = torch.autograd.grad(s, (a,), create_graph=True)
            g.sum().backward()
        # combining the two above blocks: 2 * 4 = 8
        # note that in that sense, a is saved twice
        self.assertEqual(6 * 8 * a, a.grad)

    def test_wrapped_number_saved_tensors_hooks(self):
        def err_hook(x):
            raise RuntimeError("this hook should not be called")

        with torch.autograd.graph.saved_tensors_hooks(err_hook, err_hook):
            a = torch.randn(5, requires_grad=True)
            out = (a * 3).sum()
            # 3 is saved as a saved tensor because it is a wrapped number, but
            # wrapped numbers should be special cased to not trigger saved variable hooks
            torch.autograd.grad(out, (a,))

    def test_graph_save_on_cpu(self):
        def test(get_input, cuda, pin_memory):
            with torch.autograd.graph.save_on_cpu(pin_memory):
                a = get_input()
                if cuda:
                    a.cuda()
                y = a * a
                self.assertEqual(a, y.grad_fn._saved_self)
                self.assertEqual(a, y.grad_fn._saved_other)
                self.assertEqual(a.dtype, y.grad_fn._saved_self.dtype)
                self.assertEqual(a.layout, y.grad_fn._saved_self.layout)
                if y.is_sparse:
                    y = y.to_dense()
                y.sum().backward()

                actual = 2 * a
                expected = a.grad
                if a.is_sparse:
                    actual = actual.coalesce()
                    expected = expected.coalesce()

                self.assertEqual(actual, expected)

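        # Sweep the helper above over dense float/double and sparse COO inputs,
        # with and without CUDA and pinned memory, so save_on_cpu is exercised
        # for each combination.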
        for cuda in [False] + ([True] if torch.cuda.is_available() else []):
            for pin_memory in [True, False]:
                # FloatTensor
                test(lambda: torch.randn(5, requires_grad=True), cuda, pin_memory)
                # DoubleTensor
                test(
                    lambda: torch.randn(5, requires_grad=True, dtype=torch.double),
                    cuda,
                    pin_memory,
                )
                # Sparse tensor
                x = torch.sparse_coo_tensor(
                    torch.tensor([[1, 1]]).long(),
                    torch.tensor([1.0, 1.0]),
                    requires_grad=True,
                )
                test(lambda: x, cuda, pin_memory)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_graph_save_on_cpu_cuda(self):
        def f(x):
            a = x + 1
            return a * a

        # with grad
        a = torch.ones(1, requires_grad=True, device="cuda")
        y = f(a)
        memory_with_grad = torch.cuda.memory_allocated()

        del a
        del y

        # without grad
        a = torch.ones(1, requires_grad=True, device="cuda")
        with torch.no_grad():
            y = f(a)
        memory_without_grad = torch.cuda.memory_allocated()

        self.assertGreater(memory_with_grad, memory_without_grad)

        del a
        del y

        # with hooks
        with torch.autograd.graph.save_on_cpu():
            a = torch.ones(1, requires_grad=True, device="cuda")
            y = f(a)
            memory_with_hooks = torch.cuda.memory_allocated()
            self.assertEqual(memory_with_hooks, memory_without_grad)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_scalar_grad_mixed_device(self):
        x = torch.tensor(1.0, requires_grad=True)
        y = torch.randn(2, 2, device="cuda")
        out = x * y
        out.sum().backward()

    def test_multi_grad_all_hooks(self):
        t1 = torch.rand(2, requires_grad=True)
        t2 = torch.rand(2, requires_grad=True)
        t3 = torch.rand(2, requires_grad=True)
        t4 = torch.rand(2, requires_grad=True)

        # Ensure we properly detect all types of Nodes here
        # C++ Node
        t1 = t1.mul(2)

        # Python custom Function
        class Foo(Function):
            @staticmethod
            def forward(ctx, a):
                return a.clone()

            @staticmethod
            def backward(ctx, gO):
                return gO

        t2 = Foo.apply(t2)

        # C++ Node
        t3 = torch._C._functions.UndefinedGrad()(t3)

        # C++ Custom Op
        cpp_source = """
struct CustomOpAutogradFunction : public torch::autograd::Function<CustomOpAutogradFunction> {
  static torch::Tensor forward(
      torch::autograd::AutogradContext* ctx,
      const torch::Tensor& x) {
    return x.clone();
  }

  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext *ctx,
      torch::autograd::variable_list grad_output) {
    return grad_output;
  }
};

torch::Tensor custom_op_backed_by_autograd_fn(torch::Tensor x) {
  return CustomOpAutogradFunction::apply(x);
}

TORCH_LIBRARY(test_autograd_cpp_node, m) {
    m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn);
}
        """

        module = load_inline(
            name="test_autograd_cpp_node",
            cpp_sources=cpp_source,
            functions="custom_op_backed_by_autograd_fn",
            verbose=True,
        )

        t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4)

        res = [None] * 4
        count = [0]

        def hook(grads):
            nonlocal res
            count[0] += 1
            res = [g is not None for g in grads]

        handle = torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)

        out = t2 * t3

        out.sum().backward(inputs=(t2, t3), retain_graph=True)
        self.assertEqual(count[0], 1)
        self.assertEqual(res, [False, True, True, False])

        out.sum().backward(inputs=(t1, t4), retain_graph=True)
        self.assertEqual(count[0], 1)

        out.sum().backward(inputs=(t1, t3), retain_graph=True)
        self.assertEqual(count[0], 2)
        self.assertEqual(res, [False, False, True, False])

        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                raise RuntimeError("error message")

        out = Func.apply(t2) * t3
        with self.assertRaisesRegex(RuntimeError, "error message"):
            out.sum().backward(inputs=(t2, t3), retain_graph=True)
        self.assertEqual(count[0], 2)

        handle.remove()
        out.sum().backward(inputs=(t1, t3), retain_graph=True)
        self.assertEqual(count[0], 2)

    def test_multi_grad_any_hooks(self):
        hook_id = 0
        any_hook_handles: List[RemovableHandle] = []

        class MultiOutputModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(3, 3)

            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                z = self.lin(x)
                out = torch.sin(z), torch.cos(z)
                nonlocal hook_id
                z.register_hook(partial(hook, hook_id))
                hook_id += 1
                any_hook_handles.append(
                    torch.autograd.graph.register_multi_grad_hook(
                        out, partial(hook, hook_id), mode="any"
                    )
                )
                hook_id += 1
                return out

        class Model(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = MultiOutputModule()
                self.mod2 = MultiOutputModule()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = self.mod1(x)
                z = y[0] + y[1]
                return self.mod2(z)

        hook_order: List[int] = []
        hook_count = 0

        def hook(hook_id: int, *unused):
            nonlocal hook_count
            nonlocal hook_order
            hook_count += 1
            hook_order.append(hook_id)

        # Any hooks: IDs 1 and 3; regular hooks: IDs 0 and 2
        model = Model()
        inp = torch.randn((2, 3))
        out = model(inp)
        (out[0] + out[1]).sum().backward()
        # Check that the any-hook runs only once and before the regular hook
        # for each module
        self.assertEqual(len(any_hook_handles), 2)
        self.assertEqual(hook_order, [3, 2, 1, 0])

        hook_id = 0
        hook_order.clear()
        any_hook_handles.clear()
        out = model(inp)
        for handle in any_hook_handles:
            handle.remove()
        (out[0] + out[1]).sum().backward()
        # Check that the any-hook does not run if removed
        self.assertEqual(hook_order, [2, 0])

    def test_multi_grad_hooks_invalid_mode(self):
        t1 = torch.rand(2, requires_grad=True)
        t2 = torch.rand(2, requires_grad=True)
        regex = r"Expects mode to be one of \('all', 'any'\) but got foo"
        with self.assertRaisesRegex(ValueError, regex):
            torch.autograd.graph.register_multi_grad_hook(
                (t1, t2), lambda _: None, mode="foo"
            )

    def test_pynode_destruction_deadlock(self):
        script = """
import torch

class Foo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, gO):
        return gO.clone()

def get_out():
    inp = torch.rand(2, requires_grad=True)

    # The python Function is first so that it runs
    # last in the backward pass
    right = Foo.apply(inp)

    # An op that creates new memory
    left1 = inp.clone()
    # An op that saves its input
    left2 = left1 ** 2

    # Inplace modify so that the backward for
    # left2 always raises an error
    left1 += 1

    # An op that takes both sides as input.
    # After running, both sides' last ops will be in
    # the ready queue
    # And the op for the left side will run first as it was
    # executed last during the forward
    out = left2 + right

    return out

# Nothing should be a global variable here because, from what
# I can see, Python leaks all the global objects
get_out().sum().backward()

# This used to deadlock when the PyNode is being destroyed after
# the error is raised.
"""
        try:
            subprocess.check_output(
                [sys.executable, "-c", script],
                stderr=subprocess.STDOUT,
                # On Windows, opening the subprocess with the default CWD makes `import torch`
                # fail, so just set CWD to this script's directory
                cwd=os.path.dirname(os.path.realpath(__file__)),
                # It is ok to have an extra long timeout here as a timeout means the test failed
                timeout=20,
            )
        except subprocess.TimeoutExpired as e:
            self.fail(
                msg="Example code timed out! See the code sample in the test for details."
            )
        except subprocess.CalledProcessError as e:
            if e.returncode < 0:
                # Sometimes we segfault instead of deadlocking
                self.fail("Subprocess exited with a fatal signal")
            else:
                err_msg = (
                    "RuntimeError: one of the variables needed for gradient computation"
                )
                self.assertTrue(err_msg in e.output.decode("utf-8"))

    def test_view_func_replay(self):
        with torch.autograd._force_original_view_tracking(True):

            def _assert_match_metadata(a, b):
                self.assertEqual(a.size(), b.size())
                self.assertEqual(a.stride(), b.stride())
                self.assertEqual(a.storage_offset(), b.storage_offset())
                self.assertEqual(a.device, b.device)
                self.assertEqual(a.dtype, b.dtype)

            def _test_fn(fn, inp, *args, use_unsafe_view_func=False):
                outs = fn(inp, *args)
                # handle functions that return multiple views (e.g. split)
                if isinstance(outs, torch.Tensor):
                    outs = [outs]

                for out in outs:
                    self.assertTrue(out._is_view())
                    self.assertTrue(out._base is inp)

                    # forward view_func
                    new_inp = inp.clone()
                    _assert_match_metadata(new_inp, inp)
                    if use_unsafe_view_func:
                        new_out = out._view_func_unsafe(new_inp)
                    else:
                        new_out = out._view_func(new_inp)
                    _assert_match_metadata(new_out, out)
                    self.assertEqual(new_out, out)

                    # reverse view_func
                    new_out = out.detach()
                    new_inp = out._rev_view_func_unsafe(new_out)
                    _assert_match_metadata(new_inp, inp)
                    self.assertTrue(new_inp._is_view())
                    self.assertTrue(new_inp._base is new_out)

            # test individual view ops
            _test_fn(torch.ops.aten.alias.default, torch.rand(2, 2))
            _test_fn(torch.as_strided, torch.rand(2, 2), (4,), (1,))
            _test_fn(torch.chunk, torch.rand(2, 4), 2, -1)
            _test_fn(torch.diagonal, torch.rand(4, 4))
            _test_fn(torch.ops.aten.expand.default, torch.rand(4, 1), (-1, 3))
            _test_fn(torch.narrow, torch.rand(2, 2), 0, 1, 1)
            _test_fn(torch.permute, torch.rand(2, 3, 4), (1, 0, 2))
            _test_fn(torch.select, torch.rand(2, 2), 0, 0)
            _test_fn(torch.ops.aten.slice.Tensor, torch.rand(2, 2), 1, 1, 2)
            _test_fn(torch.split, torch.rand(2, 2), 1)
            _test_fn(torch.split_with_sizes, torch.rand(2, 4), [1, 3], -1)
            _test_fn(torch.squeeze, torch.rand(2, 1, 4))
            _test_fn(torch.squeeze, torch.rand(2, 1, 4), 1)
            _test_fn(torch.squeeze, torch.rand(2, 1, 1, 4), [1, 2])
            _test_fn(torch.t, torch.rand(2, 4))
            _test_fn(torch.transpose, torch.rand(2, 4), 0, 1)
            _test_fn(torch.unbind, torch.rand(1, 5))
            _test_fn(torch.ops.aten.unfold.default, torch.rand(1, 5), 1, 3, 2)
            _test_fn(torch.unsqueeze, torch.rand(2, 4), -2)
            _test_fn(torch.ops.aten.view.default, torch.rand(2, 10), (-1, 5, 2))
            _test_fn(torch.view_as_complex, torch.rand(2, 2))
            _test_fn(torch.view_as_real, torch.rand(2, 2, dtype=torch.cfloat))

            # test view chains
            _test_fn(
                lambda x: x.unsqueeze(-1).transpose(-1, -2).squeeze(1),
                torch.randn(2, 4),
            )
            _test_fn(
                lambda x: x.chunk(2, -1)[0].transpose(0, 1).unsqueeze(-1),
                torch.randn(2, 3, 4),
            )
            _test_fn(
                lambda x: x.split_with_sizes([1, 3], -1)[0].chunk(2, 0),
                torch.randn(2, 3, 4),
            )

            # chains with missing view_func()s use as_strided() to cover the gaps
            def chain_with_only_parent_view_func(x):
                with torch.autograd._force_original_view_tracking(True):
                    x = x.split_with_sizes([1, 3], -1)[0]

                with torch.autograd._force_original_view_tracking(False):
                    x = x.chunk(2, 0)

                return x

            _test_fn(chain_with_only_parent_view_func, torch.randn(2, 3, 4))

            def chain_with_only_current_view_func(x):
                with torch.autograd._force_original_view_tracking(False):
                    x = x.split_with_sizes([1, 3], -1)[0]

                with torch.autograd._force_original_view_tracking(True):
                    x = x.chunk(2, 0)

                return x

            _test_fn(chain_with_only_current_view_func, torch.randn(2, 3, 4))

            # TODO: Move this somewhere else
            # test NT views
            from torch.nested._internal.nested_tensor import (
                nested_view_from_values_offsets,
            )

            values = torch.randn(10, 5)
            offsets = torch.tensor([0, 3, 6, 10])
            _test_fn(nested_view_from_values_offsets, values, offsets)

            nt = nested_view_from_values_offsets(values, offsets).clone().detach()
            _test_fn(
                torch.ops.aten._nested_get_values.default, nt, use_unsafe_view_func=True
            )

            def chain_nt_to_dense_back_and_forth(nt):
                # NJT1 -> dense -> NJT2 -> dense
                offsets2 = nt.offsets().clone().detach()
                return nested_view_from_values_offsets(nt.values(), offsets2).values()

            _test_fn(chain_nt_to_dense_back_and_forth, nt, use_unsafe_view_func=True)

            def chain_dense_to_nt_back_and_forth(values, offsets):
                offsets2 = offsets.clone().detach()
                # dense -> NJT1 -> dense -> NJT2
                return nested_view_from_values_offsets(
                    nested_view_from_values_offsets(values, offsets).values(), offsets2
                )

            _test_fn(
                chain_dense_to_nt_back_and_forth,
                values,
                offsets,
                use_unsafe_view_func=True,
            )

    def test_view_func_replay_with_modified_state(self):
        with torch.autograd._force_original_view_tracking(True):
            base = torch.randn(3, 4, 5)
            view = base.select(1, 2)

            def symint_visitor_fn(x):
                # modify saved index
                return x + 1

            # ensure modifying state changes view replay
            new_base = torch.randn_like(base)
            new_view = view._view_func(new_base, symint_visitor_fn=symint_visitor_fn)
            self.assertEqual(new_view, new_base.select(1, 3))

            # ensure saved state reverts back afterwards
            self.assertEqual(view._view_func(new_base), new_base.select(1, 2))

            # check modifying tensor state. currently, slice_inverse() is the only
            # view that saves a tensor
            base = torch.randn(3, 4, 5)
            sliced = base[:, 2:3, :].detach()
            view = torch.ops.aten.slice_inverse(sliced, base, 1, 2, 3, 1)

            replacement_shape = (1, 2, 3)

            def tensor_visitor_fn(x):
                # return tensor with a smaller shape than the saved one
                return torch.randn(*replacement_shape)

            # ensure modifying state changes view replay
            new_sliced = torch.ones_like(base)[:, 2:3, :].detach()
            new_view = view._view_func(new_sliced, tensor_visitor_fn=tensor_visitor_fn)
            self.assertEqual(new_view.shape, replacement_shape)
            self.assertEqual(
                new_view, new_sliced.as_strided(replacement_shape, (6, 3, 1))
            )

            # ensure saved state reverts back afterwards
            self.assertEqual(view._view_func(sliced), base)

    def test_setup_context_when_forward_has_default_args(self):
        class PowFunction(Function):
            @staticmethod
            def forward(x, y=3):
                return torch.pow(x, y)

            @staticmethod
            def setup_context(ctx, inputs, output):
                x, y = inputs
                ctx.save_for_backward(x)
                ctx.y = y

            @staticmethod
            def backward(ctx, gO):
                (x,) = ctx.saved_tensors
                y = ctx.y
                return gO * y * torch.pow(x, y - 1), None

        class PowFunctionWithClassmethod(Function):
            @classmethod
            def forward(cls, x, y=3):
                return torch.pow(x, y)

            @classmethod
            def setup_context(cls, ctx, inputs, output):
                x, y = inputs
                ctx.save_for_backward(x)
                ctx.y = y

            @classmethod
            def backward(cls, ctx, gO):
                (x,) = ctx.saved_tensors
                y = ctx.y
                return gO * y * torch.pow(x, y - 1), None

        x = torch.tensor(2.0, requires_grad=True)

        y = torch.tensor(8.0)
        y_expected = torch.tensor(12.0)

        y1 = PowFunction.apply(x)
        (y1_expected,) = torch.autograd.grad(y1, x)

        y2 = PowFunctionWithClassmethod.apply(x)
        (y2_expected,) = torch.autograd.grad(y2, x)

        self.assertEqual(y, y1)
        self.assertEqual(y_expected, y1_expected)
        self.assertEqual(y, y2)
        self.assertEqual(y_expected, y2_expected)

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_gradcheck_default_device_placement_context(self):
        # During gradcheck with fast_mode=True, we create a random vector on the CPU device using a CPU generator.
        # This test ensures that this still works when the default device is set to something else by the user.
        with torch.device("cuda"):
            x = torch.randn(3, dtype=torch.double, requires_grad=True)

            def func(inp):
                return inp**2.0

            self.assertTrue(gradcheck(func, x, fast_mode=True))


def index_perm_variable(shape, max_indices):
    if not isinstance(shape, tuple):
        shape = (shape,)

    index = torch.randperm(max_indices).narrow(0, 0, reduce(mul, shape)).view(shape)
    return index


def bernoulli_scalar():
    return torch.tensor(0, dtype=torch.uint8).bernoulli_()


class TestAutogradForwardModeBatchedGrad(TestCase):
    def test_out_of_place_basic(self):
        a = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
        b = torch.rand(4, 4, dtype=torch.double, requires_grad=True)
        self.assertTrue(
            gradcheck(
                torch.sin,
                a,
                check_forward_ad=True,
                check_batched_grad=True,
                check_batched_forward_grad=True,
            )
        )
        self.assertTrue(
            gradcheck(
                torch.add,
                (a, b),
                check_forward_ad=True,
                check_batched_grad=True,
                check_batched_forward_grad=True,
            )
        )

    def test_out_of_place_not_same_layout(self):
        input = torch.zeros([2, 2]).transpose(0, 1)
        tangent = torch.zeros([2, 2, 2])

        def jvp(tangent):
            with fwAD.dual_level():
                x = fwAD.make_dual(input, tangent)
                return fwAD.unpack_dual(x)[1]

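        # The transposed primal and the contiguous batched tangent have different
        # layouts, so make_dual has to restride/copy the tangent under vmap; the
        # unpacked tangent is therefore a different object from the one passed in.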
        x_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(tangent)

        self.assertIsNot(x_tangent, tangent)

    def test_inplace_on_view_same_layout(self):
        input = torch.zeros([2, 2])
        tangent = torch.zeros([2, 2, 2])
        base = torch.zeros([2, 2])
        view = base.view_as(base)

        def jvp(tangent):
            with fwAD.dual_level():
                x = fwAD.make_dual(input, tangent)
                view.copy_(x)
                return (
                    fwAD.unpack_dual(x)[1],
                    fwAD.unpack_dual(view)[1],
                    fwAD.unpack_dual(view._base)[1],
                )

        x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(
            tangent
        )

        self.assertFalse(
            view_tangent._is_view()
        )  # Optimization to share the same tensor!
        self.assertIs(view_tangent, base_tangent)
        self.assertIs(x_tangent, tangent)

    def test_inplace_on_view_not_same_layout(self):
        input = torch.zeros([2, 2])
        tangent = torch.zeros([2, 2, 2])
        view = torch.zeros([2, 2]).transpose(0, 1)

        def jvp(tangent):
            with fwAD.dual_level():
                x = fwAD.make_dual(input, tangent)
                view.copy_(x)
                return (
                    fwAD.unpack_dual(x)[1],
                    fwAD.unpack_dual(view)[1],
                    fwAD.unpack_dual(view._base)[1],
                )

        x_tangent, view_tangent, base_tangent = torch._vmap_internals._vmap(jvp, 0, 0)(
            tangent
        )

        self.assertIs(view_tangent._base, base_tangent)
        self.assertIs(x_tangent, tangent)
        self.assertIsNot(view_tangent, tangent)

    def test_metadata_check_for_storage_numel_skipped(self):
        # See: test_metadata_check_checks_storage_numel for the reverse of this test
        primal = torch.randn(5)[:4].detach()
        self.assertEqual(len(primal.storage()), 5)
        tangent = torch.randn(10, 4)

        def jvp(tangent):
            with fwAD.dual_level():
                dual = fwAD.make_dual(primal, tangent)
                _, unpacked_tangent = fwAD.unpack_dual(dual)

                # No copy is made
                self.assertIs(tangent, unpacked_tangent)

                # as_strided raises
                with self.assertRaisesRegex(
                    RuntimeError, "can access memory outside of `tensor`"
                ):
                    dual.as_strided((5,), (1,), 0)
            return unpacked_tangent

        torch._vmap_internals._vmap(jvp, 0, 0)(tangent)


class TestAutogradForwardMode(TestCase):
    def tearDown(self):
        # Ensure that a failing test won't make others fail
        while fwAD._current_level >= 0:
            fwAD.exit_dual_level()

        super().tearDown()

    def test_forward_level_cleanup(self):
        def get_tensor_and_weak_ref():
            # Create a new Tensor and weak reference
            t = torch.rand(2, requires_grad=True)
            return t, torch._C._WeakTensorRef(t)

        # Sanity check that the helper function works as expected
        t, t_ref = get_tensor_and_weak_ref()
        self.assertFalse(t_ref.expired())

        del t
        self.assertTrue(t_ref.expired())

        # Main test code
        foo = torch.rand(2)

        with fwAD.dual_level():
            tangent, tangent_ref = get_tensor_and_weak_ref()
            self.assertFalse(tangent_ref.expired())

            dual = fwAD.make_dual(foo, tangent)
            self.assertFalse(tangent_ref.expired())

            # Make sure that the tangent we provided has been re-used as is
            self.assertTrue(fwAD.unpack_dual(dual)[1] is tangent)

            # Make sure that dual is keeping the tangent alive
            del tangent
            self.assertFalse(tangent_ref.expired())

            # Make sure that the dual level does not keep the c++
            # version of the tangent alive
            del dual
            self.assertTrue(tangent_ref.expired())

    def test_size_check(self):
        foo = torch.rand(2)
        tangent = torch.rand(3)

        with fwAD.dual_level():
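            # A tangent whose size differs from the primal must be rejected; the
            # matching-size slice below is accepted.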
10462*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 10463*da0073e9SAndroid Build Coastguard Worker RuntimeError, 10464*da0073e9SAndroid Build Coastguard Worker "Trying to set a forward gradient that has a different size", 10465*da0073e9SAndroid Build Coastguard Worker ): 10466*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(foo, tangent) 10467*da0073e9SAndroid Build Coastguard Worker 10468*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(foo, tangent[1:]) 10469*da0073e9SAndroid Build Coastguard Worker 10470*da0073e9SAndroid Build Coastguard Worker def test_metadata_check_checks_storage_numel(self): 10471*da0073e9SAndroid Build Coastguard Worker primal = torch.randn(5)[:4].detach() 10472*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(primal.storage()), 5) 10473*da0073e9SAndroid Build Coastguard Worker tangent = torch.randn(4) 10474*da0073e9SAndroid Build Coastguard Worker 10475*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 10476*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(primal, tangent) 10477*da0073e9SAndroid Build Coastguard Worker _, unpacked_tangent = fwAD.unpack_dual(dual) 10478*da0073e9SAndroid Build Coastguard Worker 10479*da0073e9SAndroid Build Coastguard Worker # # Verify that mutating unpacked tangent does not affect the original tangent 10480*da0073e9SAndroid Build Coastguard Worker tangent_clone = tangent.clone() 10481*da0073e9SAndroid Build Coastguard Worker unpacked_tangent *= 2 10482*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(tangent_clone, tangent)) 10483*da0073e9SAndroid Build Coastguard Worker 10484*da0073e9SAndroid Build Coastguard Worker # as_strided runs without error 10485*da0073e9SAndroid Build Coastguard Worker dual.as_strided((5,), (1,), 0) 10486*da0073e9SAndroid Build Coastguard Worker 10487*da0073e9SAndroid Build Coastguard Worker def test_metadata_check_checks_ignores_size_zero(self): 10488*da0073e9SAndroid Build Coastguard Worker a = torch.ones(0).as_strided((0, 1), (1, 1), 0) 10489*da0073e9SAndroid Build Coastguard Worker b = torch.ones(0).as_strided((0, 1), (1, 0), 0) 10490*da0073e9SAndroid Build Coastguard Worker 10491*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 10492*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(a, b) 10493*da0073e9SAndroid Build Coastguard Worker torch.diagonal(dual, offset=0) 10494*da0073e9SAndroid Build Coastguard Worker 10495*da0073e9SAndroid Build Coastguard Worker input = torch.rand([0, 1], dtype=torch.complex128, requires_grad=True) 10496*da0073e9SAndroid Build Coastguard Worker func = partial(torch.diagonal, offset=0) 10497*da0073e9SAndroid Build Coastguard Worker torch.autograd.gradcheck(func, (input,), check_forward_ad=True) 10498*da0073e9SAndroid Build Coastguard Worker 10499*da0073e9SAndroid Build Coastguard Worker def test_metadata_check_when_primal_has_conj_bit(self): 10500*da0073e9SAndroid Build Coastguard Worker # Make sure the _has_same_storage_numel is a fallthrough, so that 10501*da0073e9SAndroid Build Coastguard Worker # conj bit does not materialize. If it materializes it would 10502*da0073e9SAndroid Build Coastguard Worker # cause the layout check to fail for views that do not index the 10503*da0073e9SAndroid Build Coastguard Worker # the entire storage. 
10504*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 2, dtype=torch.cdouble).conj() 10505*da0073e9SAndroid Build Coastguard Worker b = torch.rand_like(a) 10506*da0073e9SAndroid Build Coastguard Worker 10507*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_conj(a)) 10508*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(a.storage()), len(b.storage())) 10509*da0073e9SAndroid Build Coastguard Worker 10510*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 10511*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(a, b) 10512*da0073e9SAndroid Build Coastguard Worker dual[1:] 10513*da0073e9SAndroid Build Coastguard Worker 10514*da0073e9SAndroid Build Coastguard Worker def test_metadata_check_when_primal_has_neg_bit(self): 10515*da0073e9SAndroid Build Coastguard Worker # Make sure the _has_same_storage_numel is a fallthrough, so that 10516*da0073e9SAndroid Build Coastguard Worker # conj bit does not materialize. If it materializes it would 10517*da0073e9SAndroid Build Coastguard Worker # cause the layout check to fail for views that do not index the 10518*da0073e9SAndroid Build Coastguard Worker # the entire storage. 10519*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 2, dtype=torch.cdouble).conj().imag 10520*da0073e9SAndroid Build Coastguard Worker b = torch.randn(2, 2, dtype=torch.cdouble).imag 10521*da0073e9SAndroid Build Coastguard Worker 10522*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_neg(a)) 10523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(a.storage()), len(b.storage())) 10524*da0073e9SAndroid Build Coastguard Worker 10525*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 10526*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(a, b) 10527*da0073e9SAndroid Build Coastguard Worker dual[1:] 10528*da0073e9SAndroid Build Coastguard Worker 10529*da0073e9SAndroid Build Coastguard Worker def test_metadata_check_check_conj(self): 10530*da0073e9SAndroid Build Coastguard Worker keys = { 10531*da0073e9SAndroid Build Coastguard Worker "NEITHER": lambda x: x, 10532*da0073e9SAndroid Build Coastguard Worker "CONJ": lambda x: x.conj(), 10533*da0073e9SAndroid Build Coastguard Worker "NEG": lambda x: x._neg_view(), 10534*da0073e9SAndroid Build Coastguard Worker } 10535*da0073e9SAndroid Build Coastguard Worker 10536*da0073e9SAndroid Build Coastguard Worker for primal_key, tangent_key in product(keys, keys): 10537*da0073e9SAndroid Build Coastguard Worker x = keys[primal_key](torch.randn(2, 3, 4, dtype=torch.cdouble)) 10538*da0073e9SAndroid Build Coastguard Worker t = keys[tangent_key](torch.randn(2, 3, 4, dtype=torch.cdouble)) 10539*da0073e9SAndroid Build Coastguard Worker 10540*da0073e9SAndroid Build Coastguard Worker if primal_key == tangent_key: 10541*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 10542*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(x, t) 10543*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fwAD.unpack_dual(dual).tangent is t) 10544*da0073e9SAndroid Build Coastguard Worker torch.real(dual) 10545*da0073e9SAndroid Build Coastguard Worker torch.imag(dual) 10546*da0073e9SAndroid Build Coastguard Worker else: 10547*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 10548*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(x, t) 10549*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fwAD.unpack_dual(dual).tangent is not t) 10550*da0073e9SAndroid Build Coastguard Worker 
                    torch.real(dual)
                    torch.imag(dual)

    def test_metadata_check_ignore_storage_offset_for_zero_numel_tensor(self):
        # See https://github.com/pytorch/pytorch/issues/80507
        a = torch.tensor([1.0]).as_strided((0,), (1,), 1)
        b = torch.tensor([1.0]).as_strided((0,), (1,), 2)

        with fwAD.dual_level():
            dual_input = fwAD.make_dual(a, b)
            # Check that no copy is made
            self.assertIs(fwAD.unpack_dual(dual_input).tangent, b)

        a = torch.tensor([1.0]).as_strided((1,), (2,), 0)
        b = torch.tensor([1.0]).as_strided((1,), (1,), 0)

        with fwAD.dual_level():
            dual_input = fwAD.make_dual(a, b)
            dual_input[1:]

    # The following test functions ensure all of the following behaviors:
    #   - The default level system in the python binding works
    #   - Only level 0 exists and nesting is properly disabled
    #   - Printing works fine
    #   - Basic packing/unpacking works
    #   - Advanced packing/unpacking works
    #     - For memory / version counter sharing
    #     - For backward AD (regular ops)
    #   - View + inplace for both modes work fine
    #   - We do proper cleanup on exit of a level

    def test_default_level(self):
        foo = torch.rand(2)
        bar = torch.rand(2)

        with fwAD.dual_level():
            baz = fwAD.make_dual(foo, bar)
            baz_primal, baz_tangent = fwAD.unpack_dual(baz)
            self.assertEqual(baz_primal, foo)
            # We don't actually need to enforce that these two are the exact same Python
            # object, feel free to relax in the future
            self.assertIs(baz_tangent, bar)

        baz_primal, baz_tangent = fwAD.unpack_dual(baz)
        self.assertEqual(baz_primal, foo)
        self.assertEqual(baz_tangent, None)

    def test_fwd_grad_enabled(self):
        # Tests some private helper functions to enable/disable fwd grad mode
        enabled = fwAD._is_fwd_grad_enabled()
        self.assertTrue(enabled)

        try:
            torch._C._set_fwd_grad_enabled(False)
            enabled = fwAD._is_fwd_grad_enabled()
            self.assertFalse(enabled)
        finally:
            torch._C._set_fwd_grad_enabled(True)

        enabled = fwAD._is_fwd_grad_enabled()
        self.assertTrue(enabled)

    def test_set_fwd_grad_enabled(self):
        # Tests a private helper function
        try:
            torch._C._set_fwd_grad_enabled(False)
            enabled = fwAD._is_fwd_grad_enabled()
            self.assertFalse(enabled)

            with fwAD._set_fwd_grad_enabled(True):
                enabled = fwAD._is_fwd_grad_enabled()
                self.assertTrue(enabled)

            enabled = fwAD._is_fwd_grad_enabled()
            self.assertFalse(enabled)
        finally:
            torch._C._set_fwd_grad_enabled(True)

    def test_nested_level(self):
        with fwAD.dual_level() as level:
            # For now only level 0 exists
            self.assertEqual(level, 0)

        with fwAD.dual_level():
            with self.assertRaisesRegex(
                RuntimeError, "Nested forward mode AD is not supported at the moment"
            ):
                nest_level = fwAD.enter_dual_level()

    def test_set_fw_grad_having_own_fw_grad_at_same_level(self):
        foo = torch.rand(2)
        bar = torch.rand(2)
        baz = torch.rand(2)
        with fwAD.dual_level():
            dual = fwAD.make_dual(foo, bar)
            with self.assertRaisesRegex(
                RuntimeError, "has a forward gradient at the same level"
            ):
                fwAD.make_dual(baz, dual)

    def test_codegen_ignores_undefined_outputs(self):
        # This test checks that codegen silently ignores undefined outputs.
        # Below, grad_input is specified as False in grad_output_mask, so
        # convolution backward will return an undefined tensor in that position.
        # Note that for this test to work we need to make sure either grad_output
        # or weight is a dual tensor, so that grad_input requires forward grad.
        weight = torch.randn(6, 1, 30, 30)
        inp = torch.rand((1, 1, 32, 32))
        out = torch.nn.functional.conv2d(inp, weight)
        grad_out = torch.ones_like(out)

        with fwAD.dual_level():
            dual_weight = fwAD.make_dual(weight, torch.ones_like(weight))
            grad_input, _, _ = torch.ops.aten.convolution_backward(
                grad_out,
                inp,
                dual_weight,
                (0,),
                (1, 1),
                (0, 0),
                (1, 1),
                False,
                (0, 0),
                1,
                (False, True, False),
            )
        self.assertIsNone(grad_input)

    def test_make_dual_inference_tensor_in_inference_mode(self):
        with torch.inference_mode():
            foo = torch.rand(2)
            bar = torch.rand(2)
            foo_copy = foo.clone()

            with fwAD.dual_level():
                dual = fwAD.make_dual(foo, bar)
                self.assertFalse(dual._is_view())

                dual += 1
                self.assertFalse(torch.allclose(foo, foo_copy))
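
    # Illustrative sketch (a comment only, not one of the tests): the minimal
    # forward-AD pattern that the tests in this class exercise. The tensors below
    # are arbitrary examples, and the expected tangent follows from the chain rule
    # (d/dx exp(x) = exp(x)):
    #
    #     with fwAD.dual_level():
    #         dual = fwAD.make_dual(torch.rand(2), torch.ones(2))
    #         out = dual.exp()
    #         primal, tangent = fwAD.unpack_dual(out)
    #         # tangent equals the primal here (exp(x) * 1), computed in the same
    #         # forward pass as the primal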

    def test_make_dual_torch_dispatch(self):
        counter = [0]

        class MySubclass(torch.Tensor):
            def __new__(cls, data=None):
                return torch.Tensor._make_subclass(cls, data)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func.overloadpacket == torch.ops.aten.alias:
                    counter[0] += 1

                    # Make sure we can re-enable autograd here
                    with torch.overrides.enable_reentrant_dispatch():
                        foo = torch.rand(1, requires_grad=True)
                        self.assertIsNotNone(foo.exp().grad_fn)

                with no_dispatch():
                    return func(*args, **kwargs)

        a = torch.tensor(1.0)
        s = MySubclass(a)

        with fwAD.dual_level():
            # Only the primal has "alias" called on it
            fwAD.make_dual(s, torch.rand_like(s))
            self.assertEqual(counter[0], 1)
            fwAD.make_dual(torch.rand_like(s), s)
            self.assertEqual(counter[0], 1)

    def test_make_dual_forbid_integral_dtype(self):
        primal_f = torch.ones(2, 2, dtype=torch.float)
        primal_l = torch.ones(2, 2, dtype=torch.long)

        tangent_f = torch.ones(2, 2, dtype=torch.float)
        tangent_l = torch.ones(2, 2, dtype=torch.long)

        with fwAD.dual_level():
            # Float Primal and Long Tangent
            with self.assertRaisesRegex(
                ValueError, "Expected tangent to be floating point or complex"
            ):
                fwAD.make_dual(primal_f, tangent_l)

            # Long Primal and Long Tangent
            with self.assertRaisesRegex(
                ValueError, "Expected primal to be floating point or complex"
            ):
                fwAD.make_dual(primal_l, tangent_l)

            # Long Primal and Float Tangent
            with self.assertRaisesRegex(
                ValueError, "Expected primal to be floating point or complex"
            ):
                fwAD.make_dual(primal_l, tangent_f)

    def test_print(self):
        with fwAD.dual_level() as level:
            a = torch.rand(3)
            self.assertFalse("tangent=" in str(a))

            b = fwAD.make_dual(a, torch.rand(3))
            self.assertFalse("tangent=" in str(a))
            self.assertTrue("tangent=" in str(b))

            b_primal, b_tangent = fwAD.unpack_dual(b)
            self.assertFalse("tangent=" in str(b_primal))
            self.assertFalse("tangent=" in str(b_tangent))

    def test_basic_packing_unpacking(self):
        foo = torch.rand(2)
        bar = torch.rand(2)

        with fwAD.dual_level():
            baz = fwAD.make_dual(foo, bar)
            baz_primal, baz_tangent = fwAD.unpack_dual(baz)
            self.assertEqual(baz_primal, foo)
            self.assertIs(baz_tangent, bar)

            # Check unpacked dual is returned as a named tuple
            # NB: Every invocation of unpack_dual returns a new tensor view
            self.assertIsNot(baz_primal, fwAD.unpack_dual(baz).primal)
            self.assertEqual(baz_primal, fwAD.unpack_dual(baz).primal)
            self.assertIs(baz_tangent, fwAD.unpack_dual(baz).tangent)

            # Check that packing/unpacking did not change the input
            foo_primal, foo_tangent = fwAD.unpack_dual(foo)
            self.assertEqual(foo_primal, foo)
            self.assertIsNone(foo_tangent)

    def test_advanced_packing_unpacking(self):
        foo = torch.rand(2)
        bar = torch.ones(2)

        # Memory and version counter check
        with fwAD.dual_level():
            dual = fwAD.make_dual(foo, bar)

            # Ensure that they are sharing memory and version counter
            self.assertEqual(dual.storage().data_ptr(), foo.storage().data_ptr())

            # Ensure we properly share the version counter
            self.assertEqual(foo._version, dual._version)
            foo.add_(1)
            self.assertEqual(foo._version, dual._version)

            # Unpacking should only create aliases as well
            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
            self.assertEqual(dual_primal.storage().data_ptr(), foo.storage().data_ptr())
            self.assertEqual(
                dual_tangent.storage().data_ptr(), bar.storage().data_ptr()
            )
            # And the tangent is actually re-used as-is so it is still the same Tensor
            self.assertIs(dual_tangent, bar)

            # Ensure we properly share the version counter
            self.assertEqual(foo._version, dual_primal._version)
            foo.add_(1)
            self.assertEqual(foo._version, dual_primal._version)
            self.assertEqual(bar._version, dual_tangent._version)
            bar.add_(1)
            self.assertEqual(bar._version, dual_tangent._version)

        # Backward mode check
        with fwAD.dual_level():
            foo.requires_grad_()
            bar.requires_grad_()

            # Check that backward gradients properly propagate through packing/unpacking
            dual = fwAD.make_dual(foo, bar)
            p, t = fwAD.unpack_dual(dual)

            gfoo, gbar = torch.autograd.grad(
                p.sum(), (foo, bar), retain_graph=True, allow_unused=True
            )
            self.assertEqual(gfoo, torch.ones_like(foo))
            self.assertIsNone(gbar)

            gfoo, gbar = torch.autograd.grad(
                t.sum(), (foo, bar), retain_graph=True, allow_unused=True
            )
            self.assertIsNone(gfoo)
            self.assertEqual(gbar, torch.ones_like(bar))

            # Check that forward gradients are impacted by detach()
            detached_dual = dual.detach()
            out = detached_dual * 2
            p, t = fwAD.unpack_dual(out)
            self.assertFalse(p.requires_grad)
            self.assertEqual(p, foo * 2)
            self.assertIsNone(t)

            # Check that forward gradients are not impacted by no_grad
            with torch.no_grad():
                out = dual * 3
            p, t = fwAD.unpack_dual(out)
            self.assertFalse(p.requires_grad)
            self.assertFalse(t.requires_grad)
            self.assertEqual(p, foo * 3)
            self.assertEqual(t, bar * 3)

            # Check that forward gradients are not impacted by inplace detach
            dual = dual.clone()
            dual.detach_()
            out = dual * 2
            p, t = fwAD.unpack_dual(out)
            self.assertFalse(p.requires_grad)
            self.assertEqual(p, foo * 2)
            self.assertIsNone(t)

    def test_view_inplace_non_differentiable_views(self):
        original_foo = torch.rand(2, dtype=torch.double)
        original_bar = torch.ones(2, dtype=torch.double)

        # Do clones to be able to compare the values updated inplace
        # with the original content of these Tensors
        foo = original_foo.clone()
        bar = original_bar.clone()

        with fwAD.dual_level():
            # Note that in this test, we use "update" to mean computing the right tangent for the dual.
            # All the inplace operations here are expected to update the primal value of the Tensors
            # but not always their tangents.
            # Also, all mentions of "non differentiable view" here mean a non forward
            # differentiable view unless specified otherwise.
            # See note [Forward Grad View/inplace] for more details on how these views work.

            # Check that inplace ops do not update non-differentiable views
            # Non differentiable view
            dual = fwAD.make_dual(foo, bar)
            dual *= 2
            # Check that the non differentiable view's tangent was not updated
            self.assertIsNone(fwAD.unpack_dual(foo)[1])
            # Check that the computed result is correct
            self.assertEqual(bar, original_bar * 2)
            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
            self.assertEqual(foo, original_foo * 2)
            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 2)
            # Other non differentiable views
            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
            self.assertIsNone(fwAD.unpack_dual(dual_primal)[1])
            self.assertIsNone(fwAD.unpack_dual(dual_tangent)[1])
            dual_primal *= 2
            # Ensure dual's tangent did not change
            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 2)
            dual_tangent *= 2
            # Ensure dual's primal did not change
            self.assertEqual(fwAD.unpack_dual(dual)[0], original_foo * 4)
            self.assertEqual(fwAD.unpack_dual(dual)[1], original_bar * 4)

    def test_view_inplace_differentiable_views(self):
        original_foo = torch.rand(2)
        original_bar = torch.ones(2)

        # Do clones to be able to compare the values updated inplace
        # with the original content of these Tensors
        foo = original_foo.clone()
        bar = original_bar.clone()

        with fwAD.dual_level():
            # Check that inplace ops do update differentiable views but stop at non differentiable ones
            # A non differentiable view
            dual = fwAD.make_dual(foo, bar)
            # A differentiable view
            view = dual.narrow(0, 0, 1)
            view *= 2
            # Check that the non differentiable view was not updated
            self.assertIsNone(fwAD.unpack_dual(foo)[1])
            # Check that the differentiable view was updated
            self.assertEqual(fwAD.unpack_dual(dual)[1], torch.tensor([2.0, 1.0]))
            self.assertEqual(fwAD.unpack_dual(view)[1], torch.tensor([2.0]))

            # Check that we track differentiable views even for Tensors that are not dual
            baz = torch.rand(2)
            baz += dual
            self.assertEqual(fwAD.unpack_dual(baz)[1], fwAD.unpack_dual(dual)[1])
            # Updates through a view should be tracked as well
            baz = torch.rand(2)
            baz[0] = dual[0]
            self.assertEqual(fwAD.unpack_dual(baz)[1][0], fwAD.unpack_dual(dual)[1][0])
            # Unused values get a gradient of 0
            self.assertEqual(fwAD.unpack_dual(baz)[1][1], 0.0)

            # Check that forward non-differentiable views do prevent gradient update
            baz = torch.rand(2)
            view = baz.detach()
            view += dual
            self.assertIsNone(fwAD.unpack_dual(baz)[1])

    def test_view_inplace_always_creates_a_view(self):
        # See https://github.com/pytorch/pytorch/issues/67800
        # The codepath may depend on the op. At the time of writing, when self is not a dual tensor,
        # the resulting forward grad for self for...
        # - add_ has the same layout as self
        # - mul_ has the same layout as other
        # This is kind of fragile because the above depends on how the forward grad expression
        # is written. For add and mul at least, the output inherits the layout of LHS.
        # We want to handle at least these two cases.
        inplace_binary_ops = (  # Add more to this list?
            lambda x, y: x.add_(y),
            lambda x, y: x.mul_(y),
            lambda x, y: x.copy_(y),
        )

        for inplace_binary_op in inplace_binary_ops:
            base = torch.randn(2, 2)
            view = base.transpose(0, 1)

            primal = torch.randn(2, 2)
            tangent = torch.randn(2, 2)

            with fwAD.dual_level():
                dual = fwAD.make_dual(primal, tangent)
                inplace_binary_op(view, dual)

                # Verify that a view relationship is created for both the primal and tangent
                p, t = fwAD.unpack_dual(base)
                p_clone = p.clone()
                t_clone = t.clone()
                view *= 2
                p, t = fwAD.unpack_dual(base)

                self.assertTrue(torch.allclose(p_clone * 2, p))
                self.assertTrue(torch.allclose(t_clone * 2, t))

    def test_grad_cleanup(self):
        foo = torch.rand(2)
        bar = torch.rand(2)
        baz = torch.rand(2)

        with fwAD.dual_level():
            dual = fwAD.make_dual(foo, bar)
            self.assertIsNone(fwAD.unpack_dual(foo)[1])
            self.assertIs(fwAD.unpack_dual(dual)[1], bar)

        self.assertIsNone(fwAD.unpack_dual(dual)[1])

        with fwAD.dual_level():
            self.assertIsNone(fwAD.unpack_dual(foo)[1])
            new_dual = fwAD.make_dual(foo, baz)

            dual_primal, dual_tangent = fwAD.unpack_dual(dual)
            new_dual_primal, new_dual_tangent = fwAD.unpack_dual(new_dual)
            self.assertEqual(dual_primal, new_dual_primal)
            self.assertIsNone(dual_tangent)
            self.assertEqual(new_dual_tangent, baz)
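
    # Illustrative sketch (a comment only, not one of the tests): as test_grad_cleanup
    # above checks, a tangent only lives as long as the dual level it was packed in,
    # e.g.
    #     p = torch.rand(2)
    #     with fwAD.dual_level():
    #         d = fwAD.make_dual(p, torch.rand(2))
    #     fwAD.unpack_dual(d).tangent  # None once the level has exited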

    def test_detach_view_tracking(self):
        # Default detach is both forward and backward non-differentiable
        foo = torch.rand(2)
        foo_weak = torch._C._WeakTensorRef(foo)

        out = foo.detach()

        del foo
        self.assertTrue(foo_weak.expired())

    def test_out_variant(self):
        with fwAD.dual_level():
            foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
            bar = torch.rand(2)

            with self.assertRaisesRegex(RuntimeError, "out= function"):
                torch.add(bar, bar, out=foo)

            with self.assertRaisesRegex(RuntimeError, "out= function"):
                torch.add(foo, bar, out=bar)

    def test_non_differentiable(self):
        with fwAD.dual_level():
            foo = fwAD.make_dual(torch.rand(2), torch.rand(2))
            bar = torch.rand(2)

            # No differentiable outputs, shouldn't error
            eq = foo == bar

            # Inplace
            foo.eq_(bar)

    def test_create_new_zeros_with_same_meta(self):
        new_zeroes_fn = torch.ops.aten._new_zeros_with_same_feature_meta

        def check(a, b):
            def assert_same_meta(t, target):
                for num_bdim in range(t.dim()):
                    result = new_zeroes_fn(t, target, self_num_batch_dims=num_bdim)

                    self.assertEqual(result.dim(), target.dim() + num_bdim)

                    # Check size/strides match for feature dims only
                    for i in range(num_bdim, result.dim()):
                        self.assertEqual(result.size()[i], target.size()[i - num_bdim])
                        self.assertEqual(
                            result.stride()[i], target.stride()[i - num_bdim]
                        )

                    # Check that we generate strides reasonably
                    if target.is_contiguous():
                        self.assertTrue(result.is_contiguous())

                    self.assertEqual(result.storage_offset(), target.storage_offset())

                    prod_of_t_bdims = reduce(operator.mul, t.size()[:num_bdim], 1)
                    self.assertEqual(
                        len(result.storage()), len(target.storage()) * prod_of_t_bdims
                    )

                    # TensorOptions are the same
                    self.assertEqual(result.dtype, target.dtype)

            assert_same_meta(a, b)
            assert_same_meta(b, a)

        a = torch.randn(5, dtype=torch.float)
        b = torch.randn(2, 3, 4, dtype=torch.double)
        check(a, b)

        # non-contiguous case
        a = torch.randn(2, 3, 4).transpose(0, 1).contiguous().transpose(0, 1)
        b = torch.randn(2, 3, 4)
        check(a, b)

        a = torch.randn(5).narrow(0, 1, 2)
        b = torch.randn(2)
        check(a, b)

        # tensor is not a view, but still does not index the entirety of its storage
        a = torch.randn(5).resize_(4)
        b = torch.randn(4)
        check(a, b)

        # Zero-numel tensors
        a = torch.randn(1, 0, 2)
        b = torch.randn(1, 2)
        check(a, b)

        # Scalar tensor
        a = torch.tensor(1.0)
        b = torch.randn(1, 2)
        check(a, b)

    def test_backward_graph_destruction(self):
        def fn():
            a = torch.rand(10, requires_grad=True)

            da = fwAD.make_dual(torch.rand_like(a), a)

            # Create an object with a C++ cycle as:
            # db -> AutogradMeta -> ForwardGrad -> db's grad
            # db's grad -> AutogradMeta -> MulBackward
            # MulBackward -> SavedVariable -> db
            db = da.exp()

        with fwAD.dual_level():
            fn()
        # This test makes sure that we don't deadlock on exit of this
        # context manager. If we do, there is most likely something wrong
        # with the locking of the forward AD level.


# Generic device type autograd tests.
class TestAutogradDeviceType(TestCase):
    def test_min_max_median_backprops_to_all_values(self, device):
        for f in [torch.min, torch.max, torch.median, torch.nanmedian]:
            x1 = torch.tensor(
                [1.0, 0.0, 1.0, 0.0, 1.0, 0.0], device=device, requires_grad=True
            )
            x2 = torch.tensor(
                [float("nan"), float("nan"), float("nan")], requires_grad=True
            )
            for x in [x1, x2]:
                y = f(x)
                y.backward()
                self.assertEqual(x.grad.sum(), 1.0)
                self.assertEqual((x.grad == 1 / 3).sum(), 3)

    def test_scatter_index_reduce_amin_amax_backprops_to_all_values(self, device):
        # Tests that gradients are evenly distributed when there are multiple max/min values.
        # Tested here instead of adding a SampleInput, as the backward for this case is
        # non-differentiable for gradgrad, as is the case for
        # test_min_max_median_backprops_to_all_values above.
        fns = (torch.scatter_reduce, torch.index_reduce)
        reduces = ("amin", "amax")
        for fn, reduction in product(fns, reduces):
            input = torch.randn(
                (2, 3), device=device, dtype=torch.float64, requires_grad=True
            )
            src = input.clone().detach_().requires_grad_(True)
            idx = torch.arange(2).to(dtype=torch.long, device=device)
            if fn == torch.scatter_reduce:
                idx = idx.unsqueeze(-1).expand((2, 3))

            gradcheck(fn, (input, 0, idx, src, reduction), check_batched_grad=False)

    def test_scatter_index_reduce_prod_gradgrad_error(self, device):
        # Test that double backward raises an error for the case where 2 zeros in src
        # are scattered to the same position in self
        input = torch.tensor(
            [1.0], device=device, dtype=torch.float64, requires_grad=True
        )
        src = torch.tensor(
            [0.0, 0.0], device=device, dtype=torch.float64, requires_grad=True
        )
        idx = torch.tensor([0, 0], device=device, dtype=torch.long)

        for fn in (torch.scatter_reduce, torch.index_reduce):
            # Check that this case passes gradcheck
            gradcheck(fn, (input, 0, idx, src, "prod"), check_batched_grad=False)
            with self.assertRaisesRegex(
                RuntimeError, "Double backward is unsupported for"
            ):
                gradgradcheck(fn, (input, 0, idx, src, "prod"))

    @skipIfMps  # the test doesn't work on MPS as double types are not supported
    def test_parameter_resize(self, device):
        asd = torch.nn.Parameter(torch.ones(16, dtype=torch.double, device=device))

        for i in range(2):
            with torch.no_grad():
                asd.set_(asd[1:])
                asd.grad = None

            m = torch.cat((asd, asd))
            m.sum().backward()

    @skipIfMps  # the test doesn't work on MPS as double types are not supported
    @dtypes(torch.double, torch.cdouble)
    def test_sparse_ctor_getter_backward(self, device, dtype):
        # See NOTE [ Sparse: autograd and API ] on the expected behavior of this test
        def _test(size, sparse_dim, nnz, device):
            v_size = [nnz] + list(size[sparse_dim:])
            i = torch.rand(sparse_dim, nnz)
            i.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(i))
            i = i.to(torch.long)

            inp = torch.randn(
                v_size, dtype=torch.double, device=device, requires_grad=True
            )
            other = self.genSparseTensor(
                size, sparse_dim, nnz, is_uncoalesced=True, device=device, dtype=dtype
            )[0]

            def fn(v):
                x = torch.sparse_coo_tensor(i, v, size, dtype=dtype, device=device)
                y = (x + other).coalesce()
                yv = y.values()
                new_v = yv.tanh()
                z = torch.sparse_coo_tensor(y.indices(), new_v, y.size())
                return z.coalesce().values()

            gradcheck(fn, (inp,), check_batched_grad=False)
            # FIXME: make gradgradcheck work.
            # gradgradcheck(fn, (inp,), check_batched_grad=False)

            # assert that _values is non-differentiable
            with self.assertRaisesRegex(RuntimeError, "does not have a grad_fn"):
                other.detach().requires_grad_()._values().backward(
                    torch.ones_like(other._values())
                )

        for empty_i, empty_v, empty_nnz in product([True, False], repeat=3):
            sparse_size = [] if empty_i else [2, 1]
            dense_size = [1, 0, 2] if empty_v else [1, 2]
            nnz = 0 if empty_nnz else 5
            _test(sparse_size + dense_size, len(sparse_size), nnz, device)

    @skipMeta
    @skipIfMps
    @dtypes(torch.double, torch.cdouble)
    def test_sparse_backward(self, device, dtype):
        class FixedGradientFunction(Function):
            @staticmethod
            def forward(ctx, x, grad_x):
                ctx.save_for_backward(grad_x)
                return x

            @staticmethod
            def backward(ctx, grad_x):
                (saved_grad_x,) = ctx.saved_tensors
                return saved_grad_x, None

        size = torch.Size([6, 3, 2])
        i1 = torch.tensor([[0, 3, 4], [0, 2, 2]], dtype=torch.long)
        v1 = make_tensor([3, 2], dtype=dtype, device=device)
        sparse_grad1 = torch.sparse_coo_tensor(i1, v1, size, dtype=dtype, device=device)
        i2 = torch.tensor([[0, 1, 3, 4], [0, 1, 2, 2]], dtype=torch.long)
        v2 = make_tensor([4, 2], dtype=dtype, device=device)
        sparse_grad2 = torch.sparse_coo_tensor(i2, v2, size, dtype=dtype, device=device)
        dense_grad = torch.rand(size, device=device, dtype=dtype)
        fn = FixedGradientFunction

        # sparse first
        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
        (
            fn.apply(x, sparse_grad1)
            + fn.apply(x, dense_grad)
            + fn.apply(x, sparse_grad2)
        ).sum().abs().backward()
        self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
        # dense first
        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
        (
            fn.apply(x, dense_grad)
            + fn.apply(x, sparse_grad1)
            + fn.apply(x, sparse_grad2)
        ).sum().abs().backward()
        self.assertEqual(x.grad, dense_grad + sparse_grad1 + sparse_grad2)
        # sparse only
        x = torch.randn(size, dtype=dtype, device=device, requires_grad=True)
        (fn.apply(x, sparse_grad1) + fn.apply(x, sparse_grad2)).sum().abs().backward()
        self.assertEqual(x.grad, sparse_grad1 + sparse_grad2)

    @skipIfMps
    def test_sparse_mask_autograd(self, device):
        tensor = torch.randn(3, requires_grad=True, device=device)
        mask = torch.ones(3, device=device)
        mask[1] = 0
        mask = mask.to_sparse()
        converted = tensor.sparse_mask(mask).to_dense()
        converted.sum().backward()
        self.assertEqual(tensor.grad, mask.to_dense())
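
    # Illustrative sketch (a comment only, not one of the tests): sparse_mask keeps
    # only the entries at the mask's non-zero positions, so the gradient of the sum
    # above flows back only to those positions, e.g.
    #     t = torch.randn(3, requires_grad=True)
    #     m = torch.tensor([1.0, 0.0, 1.0]).to_sparse()
    #     t.sparse_mask(m).to_dense().sum().backward()
    #     # t.grad == tensor([1., 0., 1.])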

    @skipIfMps  # the test doesn't work on MPS as double types are not supported
    def test_pyscalar_conversions(self, device):
        def _test_pyscalar_conversions(t, integral_conv):
            # integral -> integral
            l = t(torch.zeros(1, 1, 1, dtype=torch.long))
            pyscalar = -12345
            l[0] = pyscalar
            self.assertEqual(integral_conv(l), pyscalar)

            # floating point -> floating point
            f = Variable(t(torch.randn(1, 1, dtype=torch.double)))
            pyscalar = -12345.1
            f[0] = pyscalar
            self.assertEqual(float(f), pyscalar)
            f[0] = nan
            self.assertTrue(math.isnan(float(f)))
            f[0] = inf
            self.assertEqual(float(f), inf)
            f[0] = -inf
            self.assertEqual(float(f), -inf)

            # integral -> floating point
            # check we can convert something that loses precision
            pyscalar = 1234567890123456789
            self.assertNotEqual(pyscalar, integral_conv(float(pyscalar)))
            l[0] = pyscalar
            self.assertEqual(float(l), float(pyscalar))

            # floating point -> integral
            f[0] = nan
            self.assertRaises(ValueError, lambda: integral_conv(f[0]))
            f[0] = inf
            self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
            f[0] = -inf
            self.assertRaises(OverflowError, lambda: integral_conv(f[0]))
            f[0] = sys.float_info.max
            self.assertEqual(integral_conv(f), sys.float_info.max)

            # bool, nonzero
            def test_nonzero(tensor, value, expected):
                tensor[0] = value
                self.assertEqual(expected, bool(tensor))
                self.assertEqual(expected, True if tensor else False)

            test_nonzero(l, 0, False)
11319*da0073e9SAndroid Build Coastguard Worker test_nonzero(l, -2, True) 11320*da0073e9SAndroid Build Coastguard Worker test_nonzero(f, 0.0, False) 11321*da0073e9SAndroid Build Coastguard Worker test_nonzero(f, sys.float_info.min, True) 11322*da0073e9SAndroid Build Coastguard Worker test_nonzero(f, nan, bool(nan)) 11323*da0073e9SAndroid Build Coastguard Worker test_nonzero(f, inf, bool(inf)) 11324*da0073e9SAndroid Build Coastguard Worker test_nonzero(f, -inf, bool(-inf)) 11325*da0073e9SAndroid Build Coastguard Worker 11326*da0073e9SAndroid Build Coastguard Worker _test_pyscalar_conversions(lambda x: x.to(device), lambda x: int(x)) 11327*da0073e9SAndroid Build Coastguard Worker 11328*da0073e9SAndroid Build Coastguard Worker @dtypesIfMPS(torch.float32) 11329*da0073e9SAndroid Build Coastguard Worker @dtypesIfCUDA( 11330*da0073e9SAndroid Build Coastguard Worker torch.half, 11331*da0073e9SAndroid Build Coastguard Worker torch.float, 11332*da0073e9SAndroid Build Coastguard Worker torch.double, 11333*da0073e9SAndroid Build Coastguard Worker torch.int8, 11334*da0073e9SAndroid Build Coastguard Worker torch.int16, 11335*da0073e9SAndroid Build Coastguard Worker torch.int32, 11336*da0073e9SAndroid Build Coastguard Worker torch.int64, 11337*da0073e9SAndroid Build Coastguard Worker ) 11338*da0073e9SAndroid Build Coastguard Worker @dtypes( 11339*da0073e9SAndroid Build Coastguard Worker torch.float, torch.double, torch.int8, torch.int16, torch.int32, torch.int64 11340*da0073e9SAndroid Build Coastguard Worker ) 11341*da0073e9SAndroid Build Coastguard Worker def test_set_requires_grad_only_for_floats(self, device, dtype): 11342*da0073e9SAndroid Build Coastguard Worker def f1(): 11343*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, dtype=dtype, device=device) 11344*da0073e9SAndroid Build Coastguard Worker a.requires_grad_() 11345*da0073e9SAndroid Build Coastguard Worker 11346*da0073e9SAndroid Build Coastguard Worker def f2(): 11347*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, dtype=dtype, device=device) 11348*da0073e9SAndroid Build Coastguard Worker a.requires_grad = True 11349*da0073e9SAndroid Build Coastguard Worker 11350*da0073e9SAndroid Build Coastguard Worker def f3(): 11351*da0073e9SAndroid Build Coastguard Worker torch.ones(1, dtype=dtype, device=device, requires_grad=True) 11352*da0073e9SAndroid Build Coastguard Worker 11353*da0073e9SAndroid Build Coastguard Worker a = torch.ones(1, dtype=dtype, device=device) 11354*da0073e9SAndroid Build Coastguard Worker a.requires_grad = False # should always work 11355*da0073e9SAndroid Build Coastguard Worker a.requires_grad_(False) 11356*da0073e9SAndroid Build Coastguard Worker 11357*da0073e9SAndroid Build Coastguard Worker for f in [f1, f2, f3]: 11358*da0073e9SAndroid Build Coastguard Worker if dtype.is_floating_point: 11359*da0073e9SAndroid Build Coastguard Worker f() 11360*da0073e9SAndroid Build Coastguard Worker else: 11361*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 11362*da0073e9SAndroid Build Coastguard Worker RuntimeError, 11363*da0073e9SAndroid Build Coastguard Worker "floating point", 11364*da0073e9SAndroid Build Coastguard Worker msg=f"dt: {a.dtype} device: {a.device}", 11365*da0073e9SAndroid Build Coastguard Worker ): 11366*da0073e9SAndroid Build Coastguard Worker f() 11367*da0073e9SAndroid Build Coastguard Worker 11368*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11369*da0073e9SAndroid Build Coastguard Worker def test_advanced_indexing_backwards_large(self, device): 11370*da0073e9SAndroid 
Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/22843 11371*da0073e9SAndroid Build Coastguard Worker n = 1 << 16 11372*da0073e9SAndroid Build Coastguard Worker x = torch.rand(n, 1, device=device, requires_grad=True) 11373*da0073e9SAndroid Build Coastguard Worker a = x[:, [0]] 11374*da0073e9SAndroid Build Coastguard Worker a.sum().backward() 11375*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, torch.ones(n, 1, device=device)) 11376*da0073e9SAndroid Build Coastguard Worker 11377*da0073e9SAndroid Build Coastguard Worker def test_advanced_indexing_backwards_memory_format(self, device): 11378*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/36956 11379*da0073e9SAndroid Build Coastguard Worker shape = (2, 8, 1, 2) 11380*da0073e9SAndroid Build Coastguard Worker i = torch.randint(1, shape, device=device).contiguous( 11381*da0073e9SAndroid Build Coastguard Worker memory_format=torch.channels_last 11382*da0073e9SAndroid Build Coastguard Worker ) 11383*da0073e9SAndroid Build Coastguard Worker x = torch.randn(shape, requires_grad=True, device=device) 11384*da0073e9SAndroid Build Coastguard Worker x[i].sum().backward() 11385*da0073e9SAndroid Build Coastguard Worker 11386*da0073e9SAndroid Build Coastguard Worker def _test_reentrant_parent_error_on_cpu(self, device): 11387*da0073e9SAndroid Build Coastguard Worker t1 = torch.rand([3, 3], requires_grad=True) 11388*da0073e9SAndroid Build Coastguard Worker t2 = torch.rand([3, 3], device=device, requires_grad=True) 11389*da0073e9SAndroid Build Coastguard Worker t3 = torch.rand([3, 3], device=device, requires_grad=True) 11390*da0073e9SAndroid Build Coastguard Worker 11391*da0073e9SAndroid Build Coastguard Worker # Parent graph cpu graph. 11392*da0073e9SAndroid Build Coastguard Worker t4 = t1 * t1 11393*da0073e9SAndroid Build Coastguard Worker t5 = TestAutograd.SimulateBackwardError.apply(t4) 11394*da0073e9SAndroid Build Coastguard Worker 11395*da0073e9SAndroid Build Coastguard Worker # Child gpu graph (much longer than parent graph). 11396*da0073e9SAndroid Build Coastguard Worker prev = t2 * t2 11397*da0073e9SAndroid Build Coastguard Worker for i in range(10): 11398*da0073e9SAndroid Build Coastguard Worker prev = prev * t2 11399*da0073e9SAndroid Build Coastguard Worker reentrant_root = prev 11400*da0073e9SAndroid Build Coastguard Worker 11401*da0073e9SAndroid Build Coastguard Worker class ReentrantFunc(Function): 11402*da0073e9SAndroid Build Coastguard Worker @staticmethod 11403*da0073e9SAndroid Build Coastguard Worker def forward(ctx, inp): 11404*da0073e9SAndroid Build Coastguard Worker return inp.clone() 11405*da0073e9SAndroid Build Coastguard Worker 11406*da0073e9SAndroid Build Coastguard Worker @staticmethod 11407*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 11408*da0073e9SAndroid Build Coastguard Worker # Reentrant backward in child will take much longer. 11409*da0073e9SAndroid Build Coastguard Worker reentrant_root.backward() 11410*da0073e9SAndroid Build Coastguard Worker return grad 11411*da0073e9SAndroid Build Coastguard Worker 11412*da0073e9SAndroid Build Coastguard Worker # Parent gpu graph. 11413*da0073e9SAndroid Build Coastguard Worker t6 = ReentrantFunc.apply(t3) 11414*da0073e9SAndroid Build Coastguard Worker t7 = t6 * t6 11415*da0073e9SAndroid Build Coastguard Worker 11416*da0073e9SAndroid Build Coastguard Worker # Parent graph will error out first, while child graph will continue executing. 
11417*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(Exception, "Simulate error"):
11418*da0073e9SAndroid Build Coastguard Worker torch.autograd.backward([t5.sum(), t7.sum()])
11419*da0073e9SAndroid Build Coastguard Worker
11420*da0073e9SAndroid Build Coastguard Worker # No grads should be accumulated, since the child graph will stop execution
11421*da0073e9SAndroid Build Coastguard Worker # after the parent receives the error.
11422*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(t2.grad)
11423*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(t1.grad)
11424*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(t3.grad)
11425*da0073e9SAndroid Build Coastguard Worker
11426*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
11427*da0073e9SAndroid Build Coastguard Worker def test_reentrant_parent_error_on_cpu(self, device):
11428*da0073e9SAndroid Build Coastguard Worker def _get_cuda_memory_usage():
11429*da0073e9SAndroid Build Coastguard Worker # we don't need a CUDA synchronize because the statistics are not tracked
11430*da0073e9SAndroid Build Coastguard Worker # at the actual freeing, but when the block is marked as free.
11431*da0073e9SAndroid Build Coastguard Worker num_devices = torch.cuda.device_count()
11432*da0073e9SAndroid Build Coastguard Worker gc.collect()
11433*da0073e9SAndroid Build Coastguard Worker return tuple(torch.cuda.memory_allocated(i) for i in range(num_devices))
11434*da0073e9SAndroid Build Coastguard Worker
11435*da0073e9SAndroid Build Coastguard Worker before = _get_cuda_memory_usage()
11436*da0073e9SAndroid Build Coastguard Worker
11437*da0073e9SAndroid Build Coastguard Worker # Run as a separate function so that gc can clean up everything when we
11438*da0073e9SAndroid Build Coastguard Worker # check for memory usage.
11439*da0073e9SAndroid Build Coastguard Worker self._test_reentrant_parent_error_on_cpu(device)
11440*da0073e9SAndroid Build Coastguard Worker
11441*da0073e9SAndroid Build Coastguard Worker # Wait for the autograd thread to clean up failed tasks.
11442*da0073e9SAndroid Build Coastguard Worker after = _get_cuda_memory_usage() 11443*da0073e9SAndroid Build Coastguard Worker start = time.time() 11444*da0073e9SAndroid Build Coastguard Worker while before != after and time.time() - start < 30: 11445*da0073e9SAndroid Build Coastguard Worker time.sleep(0.1) 11446*da0073e9SAndroid Build Coastguard Worker after = _get_cuda_memory_usage() 11447*da0073e9SAndroid Build Coastguard Worker 11448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(before, after) 11449*da0073e9SAndroid Build Coastguard Worker 11450*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS 11451*da0073e9SAndroid Build Coastguard Worker # TODO: see if these tests can be ported to OpInfos or moved to where's test suite 11452*da0073e9SAndroid Build Coastguard Worker def test_where_functional(self, device): 11453*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True) 11454*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True) 11455*da0073e9SAndroid Build Coastguard Worker cond = mask_not_all_zeros((5, 5)).to(device=device) 11456*da0073e9SAndroid Build Coastguard Worker 11457*da0073e9SAndroid Build Coastguard Worker def where(cond, x, y): 11458*da0073e9SAndroid Build Coastguard Worker return torch.where(cond, x, y) 11459*da0073e9SAndroid Build Coastguard Worker 11460*da0073e9SAndroid Build Coastguard Worker gradcheck(where, [cond, x, y], raise_exception=True) 11461*da0073e9SAndroid Build Coastguard Worker gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, device=device)]) 11462*da0073e9SAndroid Build Coastguard Worker 11463*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 1, 5, dtype=torch.double, device=device, requires_grad=True) 11464*da0073e9SAndroid Build Coastguard Worker y = torch.randn(5, 5, 1, dtype=torch.double, device=device, requires_grad=True) 11465*da0073e9SAndroid Build Coastguard Worker gradcheck(where, [cond, x, y], raise_exception=True) 11466*da0073e9SAndroid Build Coastguard Worker gradgradcheck(where, [cond, x, y], [torch.randn(5, 5, 5, device=device)]) 11467*da0073e9SAndroid Build Coastguard Worker 11468*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS 11469*da0073e9SAndroid Build Coastguard Worker def test_where_scalar(self, device): 11470*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, dtype=torch.double, device=device, requires_grad=True) 11471*da0073e9SAndroid Build Coastguard Worker scalar = 4.0 11472*da0073e9SAndroid Build Coastguard Worker cond = mask_not_all_zeros((5, 5)).to(device=device) 11473*da0073e9SAndroid Build Coastguard Worker 11474*da0073e9SAndroid Build Coastguard Worker def where_scalar_first(cond, x): 11475*da0073e9SAndroid Build Coastguard Worker return torch.where(cond, scalar, x) 11476*da0073e9SAndroid Build Coastguard Worker 11477*da0073e9SAndroid Build Coastguard Worker def where_scalar_second(cond, x): 11478*da0073e9SAndroid Build Coastguard Worker return torch.where(cond, x, scalar) 11479*da0073e9SAndroid Build Coastguard Worker 11480*da0073e9SAndroid Build Coastguard Worker gradcheck(where_scalar_first, (cond, x)) 11481*da0073e9SAndroid Build Coastguard Worker gradgradcheck(where_scalar_first, (cond, x)) 11482*da0073e9SAndroid Build Coastguard Worker 11483*da0073e9SAndroid Build Coastguard Worker gradcheck(where_scalar_second, (cond, x)) 11484*da0073e9SAndroid Build Coastguard Worker 
gradgradcheck(where_scalar_second, (cond, x)) 11485*da0073e9SAndroid Build Coastguard Worker 11486*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11487*da0073e9SAndroid Build Coastguard Worker def test_free_unneeded_tensor(self, device): 11488*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, 10, 10, device=device, requires_grad=True) 11489*da0073e9SAndroid Build Coastguard Worker m = torch.randn(1, 3, 1, 1, device=device) 11490*da0073e9SAndroid Build Coastguard Worker 11491*da0073e9SAndroid Build Coastguard Worker z = x.sum() 11492*da0073e9SAndroid Build Coastguard Worker base_mem = torch.cuda.memory_allocated() 11493*da0073e9SAndroid Build Coastguard Worker z = ((x + 2) * m).sum() 11494*da0073e9SAndroid Build Coastguard Worker end_mem = torch.cuda.memory_allocated() 11495*da0073e9SAndroid Build Coastguard Worker 11496*da0073e9SAndroid Build Coastguard Worker # In the end the memory usage should remain equal, because neither of 11497*da0073e9SAndroid Build Coastguard Worker # (x + 2) and ((x + 2) * m) should be kept alive for backward, while the 11498*da0073e9SAndroid Build Coastguard Worker # previous allocation of z had the same size as the current one. 11499*da0073e9SAndroid Build Coastguard Worker self.assertEqual(base_mem, end_mem) 11500*da0073e9SAndroid Build Coastguard Worker 11501*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11502*da0073e9SAndroid Build Coastguard Worker def test_pin_memory(self, device): 11503*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, dtype=torch.double, requires_grad=True) 11504*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x.pin_memory()) 11505*da0073e9SAndroid Build Coastguard Worker self.assertIsNot(x, x.pin_memory()) 11506*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.pin_memory().requires_grad) 11507*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x.pin_memory(), [x]) 11508*da0073e9SAndroid Build Coastguard Worker gradgradcheck(lambda x: x.pin_memory(), [x]) 11509*da0073e9SAndroid Build Coastguard Worker 11510*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11511*da0073e9SAndroid Build Coastguard Worker def test_profiler_emit_nvtx(self, device): 11512*da0073e9SAndroid Build Coastguard Worker # This test is not intended to ensure correctness of nvtx ranges. 11513*da0073e9SAndroid Build Coastguard Worker # That would require something a great deal more complex (you'd have to create a 11514*da0073e9SAndroid Build Coastguard Worker # profile in a subprocess, open it, and parse the sql somehow). 11515*da0073e9SAndroid Build Coastguard Worker # This test is merely intended to catch if emit_nvtx breaks on construction. 
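# For context, a typical (illustrative) use of emit_nvtx outside this test is
# to run a script under an NVIDIA profiler (e.g. nsys/nvprof) and wrap the
# region of interest so autograd ops show up as NVTX ranges:
#     with torch.cuda.profiler.profile():
#         with torch.autograd.profiler.emit_nvtx():
#             loss = model(inputs).sum()
#             loss.backward()
# The `model`/`inputs` names here are placeholders, not objects from this file.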
11516*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device) 11517*da0073e9SAndroid Build Coastguard Worker with torch.cuda.profiler.profile(): 11518*da0073e9SAndroid Build Coastguard Worker with emit_nvtx(): 11519*da0073e9SAndroid Build Coastguard Worker a.add(1.0) 11520*da0073e9SAndroid Build Coastguard Worker 11521*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11522*da0073e9SAndroid Build Coastguard Worker def test_rnn_backward_to_input_but_not_parameters(self, device): 11523*da0073e9SAndroid Build Coastguard Worker # this checks whether it is possible to not require 11524*da0073e9SAndroid Build Coastguard Worker # weight parameters, but require inputs, see #7722 11525*da0073e9SAndroid Build Coastguard Worker l = torch.nn.LSTM(2, 3).to(device) 11526*da0073e9SAndroid Build Coastguard Worker for p in l.parameters(): 11527*da0073e9SAndroid Build Coastguard Worker p.requires_grad = False 11528*da0073e9SAndroid Build Coastguard Worker s = torch.randn(1, 1, 2, requires_grad=True, device=device) 11529*da0073e9SAndroid Build Coastguard Worker out, _ = l(s) 11530*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 11531*da0073e9SAndroid Build Coastguard Worker self.assertFalse(s.grad is None or s.grad.abs().sum().item() == 0) 11532*da0073e9SAndroid Build Coastguard Worker 11533*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(not torch.profiler.itt.is_available(), "ITT is required") 11534*da0073e9SAndroid Build Coastguard Worker def test_profiler_emit_itt(self, device): 11535*da0073e9SAndroid Build Coastguard Worker # This test is not intended to ensure correctness of itt ranges. 11536*da0073e9SAndroid Build Coastguard Worker # That would require something a great deal more complex (you'd have to create a 11537*da0073e9SAndroid Build Coastguard Worker # profile in a subprocess, open it, and parse the sql somehow). 11538*da0073e9SAndroid Build Coastguard Worker # This test is merely intended to catch if emit_itt breaks on construction. 
11539*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1, 2, 3], dtype=torch.float32, device=device) 11540*da0073e9SAndroid Build Coastguard Worker with emit_itt(): 11541*da0073e9SAndroid Build Coastguard Worker a.add(1.0) 11542*da0073e9SAndroid Build Coastguard Worker 11543*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work as randn is not supported with type long 11544*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(1) 11545*da0073e9SAndroid Build Coastguard Worker def test_grad_assignment(self, devices): 11546*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, device=devices[0]) 11547*da0073e9SAndroid Build Coastguard Worker 11548*da0073e9SAndroid Build Coastguard Worker # Tests that the wrong type raises 11549*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "expected to be a Tensor or None"): 11550*da0073e9SAndroid Build Coastguard Worker x.grad = 0 11551*da0073e9SAndroid Build Coastguard Worker 11552*da0073e9SAndroid Build Coastguard Worker # Tests that the wrong shape raises 11553*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11554*da0073e9SAndroid Build Coastguard Worker x.grad = torch.randn(2, 2, device=devices[0]) 11555*da0073e9SAndroid Build Coastguard Worker 11556*da0073e9SAndroid Build Coastguard Worker # Tests that the wrong dtype raises 11557*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11558*da0073e9SAndroid Build Coastguard Worker x.grad = torch.randn(5, 5, dtype=torch.long, device=devices[0]) 11559*da0073e9SAndroid Build Coastguard Worker 11560*da0073e9SAndroid Build Coastguard Worker # Tests that self-assignment raises 11561*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11562*da0073e9SAndroid Build Coastguard Worker x.grad = x 11563*da0073e9SAndroid Build Coastguard Worker 11564*da0073e9SAndroid Build Coastguard Worker # Tests device -> cpu grad assignment raises 11565*da0073e9SAndroid Build Coastguard Worker if self.device_type != "cpu": 11566*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11567*da0073e9SAndroid Build Coastguard Worker t_cpu = torch.rand(5, 5) 11568*da0073e9SAndroid Build Coastguard Worker t_cpu.grad = torch.randn(5, 5, device=devices[0]) 11569*da0073e9SAndroid Build Coastguard Worker 11570*da0073e9SAndroid Build Coastguard Worker # Tests half type on CUDA 11571*da0073e9SAndroid Build Coastguard Worker if self.device_type == "cuda": 11572*da0073e9SAndroid Build Coastguard Worker x = x.to(dtype=torch.half, device=devices[0]) 11573*da0073e9SAndroid Build Coastguard Worker x.grad = torch.zeros_like(x) 11574*da0073e9SAndroid Build Coastguard Worker 11575*da0073e9SAndroid Build Coastguard Worker # Tests cross-device assignment raises 11576*da0073e9SAndroid Build Coastguard Worker if len(devices) > 1: 11577*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, device=devices[0]) 11578*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11579*da0073e9SAndroid Build Coastguard Worker x.grad = torch.randn(5, 5, device=devices[1]) 11580*da0073e9SAndroid Build Coastguard Worker 11581*da0073e9SAndroid Build Coastguard Worker @dtypesIfMPS(torch.float32) 11582*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(1) 11583*da0073e9SAndroid Build Coastguard Worker @dtypes(torch.float, torch.double) 11584*da0073e9SAndroid Build Coastguard Worker def test_requires_grad_factory(self, devices, dtype): 
11585*da0073e9SAndroid Build Coastguard Worker fns = [torch.ones_like, torch.randn_like] 11586*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 3, dtype=dtype, device=devices[0]) 11587*da0073e9SAndroid Build Coastguard Worker 11588*da0073e9SAndroid Build Coastguard Worker for fn in fns: 11589*da0073e9SAndroid Build Coastguard Worker for requires_grad in [True, False]: 11590*da0073e9SAndroid Build Coastguard Worker output = fn( 11591*da0073e9SAndroid Build Coastguard Worker x, dtype=dtype, device=devices[0], requires_grad=requires_grad 11592*da0073e9SAndroid Build Coastguard Worker ) 11593*da0073e9SAndroid Build Coastguard Worker self.assertEqual(requires_grad, output.requires_grad) 11594*da0073e9SAndroid Build Coastguard Worker self.assertIs(dtype, output.dtype) 11595*da0073e9SAndroid Build Coastguard Worker self.assertEqual(devices[0], str(x.device)) 11596*da0073e9SAndroid Build Coastguard Worker 11597*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 11598*da0073e9SAndroid Build Coastguard Worker def test_unused_output_device(self, devices): 11599*da0073e9SAndroid Build Coastguard Worker from torch.nn.parallel._functions import Broadcast 11600*da0073e9SAndroid Build Coastguard Worker 11601*da0073e9SAndroid Build Coastguard Worker x = torch.randn(5, 5, dtype=torch.float, device=devices[0], requires_grad=True) 11602*da0073e9SAndroid Build Coastguard Worker outputs = Broadcast.apply(list(range(len(devices))), x) 11603*da0073e9SAndroid Build Coastguard Worker y = outputs[-1] * 2 11604*da0073e9SAndroid Build Coastguard Worker y.sum().backward() 11605*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, torch.ones(5, 5) * 2) 11606*da0073e9SAndroid Build Coastguard Worker 11607*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 11608*da0073e9SAndroid Build Coastguard Worker def test_backward_device(self, devices): 11609*da0073e9SAndroid Build Coastguard Worker # check that current device matches the variable's device 11610*da0073e9SAndroid Build Coastguard Worker device = [None] 11611*da0073e9SAndroid Build Coastguard Worker 11612*da0073e9SAndroid Build Coastguard Worker class Identity(torch.autograd.Function): 11613*da0073e9SAndroid Build Coastguard Worker @staticmethod 11614*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 11615*da0073e9SAndroid Build Coastguard Worker return x.clone() 11616*da0073e9SAndroid Build Coastguard Worker 11617*da0073e9SAndroid Build Coastguard Worker @staticmethod 11618*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_output): 11619*da0073e9SAndroid Build Coastguard Worker device[0] = grad_output.device 11620*da0073e9SAndroid Build Coastguard Worker return grad_output.clone() 11621*da0073e9SAndroid Build Coastguard Worker 11622*da0073e9SAndroid Build Coastguard Worker v = torch.randn(1, device=devices[1], requires_grad=True) 11623*da0073e9SAndroid Build Coastguard Worker Identity.apply(v).backward() 11624*da0073e9SAndroid Build Coastguard Worker self.assertEqual(str(device[0]), devices[1]) 11625*da0073e9SAndroid Build Coastguard Worker 11626*da0073e9SAndroid Build Coastguard Worker @deviceCountAtLeast(2) 11627*da0073e9SAndroid Build Coastguard Worker def test_inputbuffer_add_multidevice(self, devices): 11628*da0073e9SAndroid Build Coastguard Worker input = torch.randn(1, device=devices[0], requires_grad=True) 11629*da0073e9SAndroid Build Coastguard Worker output = input.to(device=devices[1]) + input.to(device=devices[1]) 11630*da0073e9SAndroid Build Coastguard Worker output.backward() 
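# A small self-contained sketch of the cross-device behaviour checked above:
# gradients computed on one device are moved back and accumulated on the
# leaf's own device. The helper name is hypothetical and the snippet assumes
# a CUDA device is available.
def _cross_device_grad_device_sketch():
    import torch

    x = torch.randn(3, requires_grad=True)  # CPU leaf
    y = x.to("cuda") * 2  # computation happens on the GPU
    y.sum().backward()
    # The accumulated gradient lives on the leaf's device, not on the device
    # where the backward ops ran.
    assert x.grad.device.type == "cpu"
    return x.grad.device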
11631*da0073e9SAndroid Build Coastguard Worker 11632*da0073e9SAndroid Build Coastguard Worker @onlyCPU 11633*da0073e9SAndroid Build Coastguard Worker def test_copy_(self, device): 11634*da0073e9SAndroid Build Coastguard Worker # At the time of writing this test, copy_ is not generated from native_functions.yaml 11635*da0073e9SAndroid Build Coastguard Worker # there was a bug that bfloat16 was not recognized as floating. 11636*da0073e9SAndroid Build Coastguard Worker x = torch.randn(10, device=device, requires_grad=True) 11637*da0073e9SAndroid Build Coastguard Worker floating_dt = floating_types_and(torch.half, torch.bfloat16) 11638*da0073e9SAndroid Build Coastguard Worker for dt in floating_dt: 11639*da0073e9SAndroid Build Coastguard Worker y = torch.empty(10, device=device, dtype=dt) 11640*da0073e9SAndroid Build Coastguard Worker y.copy_(x) 11641*da0073e9SAndroid Build Coastguard Worker self.assertTrue(y.requires_grad) 11642*da0073e9SAndroid Build Coastguard Worker z = x.to(torch.bfloat16) 11643*da0073e9SAndroid Build Coastguard Worker self.assertTrue(z.requires_grad) 11644*da0073e9SAndroid Build Coastguard Worker 11645*da0073e9SAndroid Build Coastguard Worker def test_copy_forward_ad_broadcasting(self, device): 11646*da0073e9SAndroid Build Coastguard Worker # copy_ allows the src to have a different shape from self as long as src is 11647*da0073e9SAndroid Build Coastguard Worker # broadcastable to self. Make sure forward AD handles this case. 11648*da0073e9SAndroid Build Coastguard Worker primal = torch.rand(3, 3, device=device) 11649*da0073e9SAndroid Build Coastguard Worker tangent = torch.rand(3, 3, device=device) 11650*da0073e9SAndroid Build Coastguard Worker non_dual = torch.rand(1, 3, 3, device=device) 11651*da0073e9SAndroid Build Coastguard Worker 11652*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 11653*da0073e9SAndroid Build Coastguard Worker dual = fwAD.make_dual(primal, tangent) 11654*da0073e9SAndroid Build Coastguard Worker non_dual.copy_(dual) 11655*da0073e9SAndroid Build Coastguard Worker 11656*da0073e9SAndroid Build Coastguard Worker def test_copy_forward_ad_same_layout_copies_grad(self, device): 11657*da0073e9SAndroid Build Coastguard Worker primal = torch.tensor([[3.0], [4.0]], device=device) 11658*da0073e9SAndroid Build Coastguard Worker tangent = torch.tensor([[5.0], [6.0]], device=device) 11659*da0073e9SAndroid Build Coastguard Worker 11660*da0073e9SAndroid Build Coastguard Worker with fwAD.dual_level(): 11661*da0073e9SAndroid Build Coastguard Worker x_dual = fwAD.make_dual(primal, tangent) 11662*da0073e9SAndroid Build Coastguard Worker non_dual = torch.tensor([[1.0], [2.0]]) 11663*da0073e9SAndroid Build Coastguard Worker non_dual.copy_(x_dual) 11664*da0073e9SAndroid Build Coastguard Worker self.assertTrue(fwAD.unpack_dual(non_dual).tangent is not tangent) 11665*da0073e9SAndroid Build Coastguard Worker 11666*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11667*da0073e9SAndroid Build Coastguard Worker def test_simple_reentrant_cross_device(self, device): 11668*da0073e9SAndroid Build Coastguard Worker class ReentrantFunc(Function): 11669*da0073e9SAndroid Build Coastguard Worker _cpu_mode = True 11670*da0073e9SAndroid Build Coastguard Worker 11671*da0073e9SAndroid Build Coastguard Worker @staticmethod 11672*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 11673*da0073e9SAndroid Build Coastguard Worker return x * (x + 2) 11674*da0073e9SAndroid Build Coastguard Worker 11675*da0073e9SAndroid Build Coastguard Worker @staticmethod 
11676*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_output):
11677*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad():
11678*da0073e9SAndroid Build Coastguard Worker if ReentrantFunc._cpu_mode:
11679*da0073e9SAndroid Build Coastguard Worker new_param = torch.randn(2, 2, requires_grad=True)
11680*da0073e9SAndroid Build Coastguard Worker (new_param**2).sum().backward()
11681*da0073e9SAndroid Build Coastguard Worker else:
11682*da0073e9SAndroid Build Coastguard Worker new_param = torch.randn(2, 2, device=device, requires_grad=True)
11683*da0073e9SAndroid Build Coastguard Worker (new_param**2).sum().backward()
11684*da0073e9SAndroid Build Coastguard Worker return grad_output
11685*da0073e9SAndroid Build Coastguard Worker
11686*da0073e9SAndroid Build Coastguard Worker # Reentrant starts on GPU thread, finishes on GPU thread
11687*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, device=device, requires_grad=True)
11688*da0073e9SAndroid Build Coastguard Worker out = ReentrantFunc.apply(x)
11689*da0073e9SAndroid Build Coastguard Worker out.sum().backward()
11690*da0073e9SAndroid Build Coastguard Worker
11691*da0073e9SAndroid Build Coastguard Worker # Reentrant starts on CPU thread, finishes on GPU thread
11692*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, requires_grad=True)
11693*da0073e9SAndroid Build Coastguard Worker # set ReentrantFunc node to GPU to emit tasks to GPU queue
11694*da0073e9SAndroid Build Coastguard Worker ReentrantFunc._cpu_mode = False
11695*da0073e9SAndroid Build Coastguard Worker out = ReentrantFunc.apply(x)
11696*da0073e9SAndroid Build Coastguard Worker out.sum().backward()
11697*da0073e9SAndroid Build Coastguard Worker
11698*da0073e9SAndroid Build Coastguard Worker # Reentrant starts on GPU thread, finishes on CPU thread
11699*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2, device=device, requires_grad=True)
11700*da0073e9SAndroid Build Coastguard Worker # set ReentrantFunc node to CPU to emit tasks to CPU queue
11701*da0073e9SAndroid Build Coastguard Worker ReentrantFunc._cpu_mode = True
11702*da0073e9SAndroid Build Coastguard Worker out = ReentrantFunc.apply(x)
11703*da0073e9SAndroid Build Coastguard Worker out.sum().backward()
11704*da0073e9SAndroid Build Coastguard Worker
11705*da0073e9SAndroid Build Coastguard Worker @onlyCUDA
11706*da0073e9SAndroid Build Coastguard Worker def test_cross_device_reentrant_autograd(self, device):
11707*da0073e9SAndroid Build Coastguard Worker # Output on gpu so that this task will be associated with the gpu thread
11708*da0073e9SAndroid Build Coastguard Worker def fn_on_gpu(inp):
11709*da0073e9SAndroid Build Coastguard Worker # Artificially increase the priority of the next op to make sure it runs
11710*da0073e9SAndroid Build Coastguard Worker # as soon as we reach it, before the ops of branch1.
11711*da0073e9SAndroid Build Coastguard Worker dummy = inp * 2 * 2 * 2 * 2
11712*da0073e9SAndroid Build Coastguard Worker return inp.to(device=device)
11713*da0073e9SAndroid Build Coastguard Worker
11714*da0073e9SAndroid Build Coastguard Worker def parent_on_cpu(inp):
11715*da0073e9SAndroid Build Coastguard Worker # Slow branch of ops on gpu so that the work queue for the gpu thread
11716*da0073e9SAndroid Build Coastguard Worker # won't empty too quickly.
They also have smaller priorities than the 11717*da0073e9SAndroid Build Coastguard Worker # ones created by fn_on_gpu 11718*da0073e9SAndroid Build Coastguard Worker branch1 = inp.to(device=device) 11719*da0073e9SAndroid Build Coastguard Worker branch1 = branch1 / branch1 11720*da0073e9SAndroid Build Coastguard Worker branch1 = branch1 / branch1 11721*da0073e9SAndroid Build Coastguard Worker branch1 = branch1 / branch1 11722*da0073e9SAndroid Build Coastguard Worker # Perform checkpoint on cpu tensors. So the last op performed in the reentrant 11723*da0073e9SAndroid Build Coastguard Worker # autograd is an AccumulateGrad that runs on the cpu thread for the gpu thread. 11724*da0073e9SAndroid Build Coastguard Worker # So the cpu thread will notify the gpu thread with an empty NodeTask. 11725*da0073e9SAndroid Build Coastguard Worker branch2 = checkpoint(fn_on_gpu, inp, use_reentrant=True) 11726*da0073e9SAndroid Build Coastguard Worker out = branch2 + branch1 11727*da0073e9SAndroid Build Coastguard Worker return out 11728*da0073e9SAndroid Build Coastguard Worker 11729*da0073e9SAndroid Build Coastguard Worker inp = torch.rand(2, requires_grad=True) 11730*da0073e9SAndroid Build Coastguard Worker out = parent_on_cpu(inp) 11731*da0073e9SAndroid Build Coastguard Worker # This will segfault if the empty NodeTask is not handled properly in the 11732*da0073e9SAndroid Build Coastguard Worker # gpu thread ReadyQueue 11733*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 11734*da0073e9SAndroid Build Coastguard Worker 11735*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_backprop_base(self, device): 11736*da0073e9SAndroid Build Coastguard Worker # modify view and back-prop through base 11737*da0073e9SAndroid Build Coastguard Worker root = torch.randn(2, 2, device=device, requires_grad=True) 11738*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11739*da0073e9SAndroid Build Coastguard Worker v1 = x.narrow(0, 0, 1) 11740*da0073e9SAndroid Build Coastguard Worker v1.mul_(2) 11741*da0073e9SAndroid Build Coastguard Worker x.sum().backward() 11742*da0073e9SAndroid Build Coastguard Worker self.assertEqual(root.grad.tolist(), [[2, 2], [1, 1]]) 11743*da0073e9SAndroid Build Coastguard Worker 11744*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_backprop_view_of_view(self, device): 11745*da0073e9SAndroid Build Coastguard Worker # modify view and backprop through view-of-view 11746*da0073e9SAndroid Build Coastguard Worker root = torch.randn(2, 2, device=device, requires_grad=True) 11747*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11748*da0073e9SAndroid Build Coastguard Worker v1 = x.narrow(0, 0, 1) 11749*da0073e9SAndroid Build Coastguard Worker v2 = x.narrow(0, 0, 1) 11750*da0073e9SAndroid Build Coastguard Worker v1.mul_(2) 11751*da0073e9SAndroid Build Coastguard Worker v2.sum().backward() 11752*da0073e9SAndroid Build Coastguard Worker self.assertEqual(root.grad.tolist(), [[2, 2], [0, 0]]) 11753*da0073e9SAndroid Build Coastguard Worker 11754*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_of_view(self, device): 11755*da0073e9SAndroid Build Coastguard Worker # modify view-of-view and backprop through base 11756*da0073e9SAndroid Build Coastguard Worker root = torch.randn(2, 2, device=device, requires_grad=True) 11757*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11758*da0073e9SAndroid Build Coastguard Worker 11759*da0073e9SAndroid Build Coastguard Worker v1 = x.narrow(0, 0, 1) 11760*da0073e9SAndroid Build Coastguard 
Worker v2 = v1.narrow(1, 1, 1) 11761*da0073e9SAndroid Build Coastguard Worker v2.mul_(2) 11762*da0073e9SAndroid Build Coastguard Worker x.sum().backward() 11763*da0073e9SAndroid Build Coastguard Worker self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1]]) 11764*da0073e9SAndroid Build Coastguard Worker 11765*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS as double types are not supported 11766*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_then_no_grad(self, device): 11767*da0073e9SAndroid Build Coastguard Worker # Perform an in-place operation on a view of a non-leaf variable. 11768*da0073e9SAndroid Build Coastguard Worker a = torch.ones(3, 1, dtype=torch.double, device=device, requires_grad=True) 11769*da0073e9SAndroid Build Coastguard Worker b = a * 2 11770*da0073e9SAndroid Build Coastguard Worker c = b.view_as(b) 11771*da0073e9SAndroid Build Coastguard Worker c[0][0] = 3 11772*da0073e9SAndroid Build Coastguard Worker 11773*da0073e9SAndroid Build Coastguard Worker # Force a graph update with grad disabled. 11774*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 11775*da0073e9SAndroid Build Coastguard Worker c.grad_fn 11776*da0073e9SAndroid Build Coastguard Worker 11777*da0073e9SAndroid Build Coastguard Worker c.sum().backward() 11778*da0073e9SAndroid Build Coastguard Worker 11779*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS as double types are not supported 11780*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_gradcheck(self, device): 11781*da0073e9SAndroid Build Coastguard Worker # gradcheck modifications to views 11782*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True) 11783*da0073e9SAndroid Build Coastguard Worker b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True) 11784*da0073e9SAndroid Build Coastguard Worker 11785*da0073e9SAndroid Build Coastguard Worker def func(root, b): 11786*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11787*da0073e9SAndroid Build Coastguard Worker x.narrow(1, 2, 2).narrow(0, 1, 2).mul_(b) 11788*da0073e9SAndroid Build Coastguard Worker x.narrow(1, 0, 2).narrow(0, 1, 2).mul_(b) 11789*da0073e9SAndroid Build Coastguard Worker return x 11790*da0073e9SAndroid Build Coastguard Worker 11791*da0073e9SAndroid Build Coastguard Worker gradcheck(func, [a, b], raise_exception=True) 11792*da0073e9SAndroid Build Coastguard Worker go = torch.randn( 11793*da0073e9SAndroid Build Coastguard Worker a.size(), dtype=torch.double, device=device, requires_grad=True 11794*da0073e9SAndroid Build Coastguard Worker ) 11795*da0073e9SAndroid Build Coastguard Worker gradgradcheck(func, (a, b), (go,)) 11796*da0073e9SAndroid Build Coastguard Worker 11797*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_multiple_outputs(self, device): 11798*da0073e9SAndroid Build Coastguard Worker root = torch.arange(9.0, dtype=torch.double).reshape(3, 3).requires_grad_() 11799*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11800*da0073e9SAndroid Build Coastguard Worker v1 = x.unbind() 11801*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11802*da0073e9SAndroid Build Coastguard Worker v1[0].mul_(2) 11803*da0073e9SAndroid Build Coastguard Worker 11804*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS as double types are not supported 11805*da0073e9SAndroid Build Coastguard Worker def 
test_inplace_on_view_of_multiple_output_view(self, device): 11806*da0073e9SAndroid Build Coastguard Worker a = torch.rand( 11807*da0073e9SAndroid Build Coastguard Worker 10, dtype=torch.double, device=device, requires_grad=True 11808*da0073e9SAndroid Build Coastguard Worker ).clone() 11809*da0073e9SAndroid Build Coastguard Worker b = a.unbind(0) 11810*da0073e9SAndroid Build Coastguard Worker c = b[0].view_as(b[0]) 11811*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11812*da0073e9SAndroid Build Coastguard Worker c.mul_(2) 11813*da0073e9SAndroid Build Coastguard Worker 11814*da0073e9SAndroid Build Coastguard Worker @skipIfMps # MPS backend doesn't support double types 11815*da0073e9SAndroid Build Coastguard Worker def test_inplace_multiple_output_view_of_view(self, device): 11816*da0073e9SAndroid Build Coastguard Worker a = torch.rand( 11817*da0073e9SAndroid Build Coastguard Worker 10, dtype=torch.double, device=device, requires_grad=True 11818*da0073e9SAndroid Build Coastguard Worker ).clone() 11819*da0073e9SAndroid Build Coastguard Worker b = a.view_as(a) 11820*da0073e9SAndroid Build Coastguard Worker c = b.unbind(0) 11821*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(RuntimeError): 11822*da0073e9SAndroid Build Coastguard Worker c[0].mul_(2) 11823*da0073e9SAndroid Build Coastguard Worker 11824*da0073e9SAndroid Build Coastguard Worker @skipIfMps # MPS backend doesn't support double types 11825*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_makes_base_require_grad(self, device): 11826*da0073e9SAndroid Build Coastguard Worker # in-place modification to view makes base require grad 11827*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=False) 11828*da0073e9SAndroid Build Coastguard Worker b = torch.randn(4, 2, dtype=torch.double, device=device, requires_grad=True) 11829*da0073e9SAndroid Build Coastguard Worker 11830*da0073e9SAndroid Build Coastguard Worker def func(root, b): 11831*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11832*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x.requires_grad) 11833*da0073e9SAndroid Build Coastguard Worker x.narrow(1, 2, 2).mul_(b) 11834*da0073e9SAndroid Build Coastguard Worker self.assertTrue(x.requires_grad) 11835*da0073e9SAndroid Build Coastguard Worker return x 11836*da0073e9SAndroid Build Coastguard Worker 11837*da0073e9SAndroid Build Coastguard Worker gradcheck(func, [a, b], raise_exception=True) 11838*da0073e9SAndroid Build Coastguard Worker go = torch.randn( 11839*da0073e9SAndroid Build Coastguard Worker a.size(), dtype=torch.double, device=device, requires_grad=True 11840*da0073e9SAndroid Build Coastguard Worker ) 11841*da0073e9SAndroid Build Coastguard Worker gradgradcheck(func, (a, b), (go,)) 11842*da0073e9SAndroid Build Coastguard Worker 11843*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_backprop_view(self, device): 11844*da0073e9SAndroid Build Coastguard Worker # modify view and backprop through view 11845*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([2.0, 5.0], device=device, requires_grad=False) 11846*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([3.0], device=device, requires_grad=True) 11847*da0073e9SAndroid Build Coastguard Worker res = a.narrow(0, 1, 1).mul_(b) 11848*da0073e9SAndroid Build Coastguard Worker res.sum().backward() 11849*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad.tolist(), [5]) 11850*da0073e9SAndroid Build 
Coastguard Worker self.assertIsNone(a.grad) 11851*da0073e9SAndroid Build Coastguard Worker 11852*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS as double types are not supported 11853*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_modify_base(self, device): 11854*da0073e9SAndroid Build Coastguard Worker # Test that an in-place operation on a base that forced it to require 11855*da0073e9SAndroid Build Coastguard Worker # grad also forces any previous views to require grad and backprop 11856*da0073e9SAndroid Build Coastguard Worker # correctly 11857*da0073e9SAndroid Build Coastguard Worker r = torch.ones(1, dtype=torch.double, device=device, requires_grad=True) 11858*da0073e9SAndroid Build Coastguard Worker 11859*da0073e9SAndroid Build Coastguard Worker def fn(r): 11860*da0073e9SAndroid Build Coastguard Worker x = torch.ones(5, dtype=torch.double, device=device) 11861*da0073e9SAndroid Build Coastguard Worker v = x.select(0, 1) 11862*da0073e9SAndroid Build Coastguard Worker self.assertFalse(v.requires_grad) 11863*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(v.grad_fn) 11864*da0073e9SAndroid Build Coastguard Worker x.add_(r) # v is now dependent on r due to the in-place op on x 11865*da0073e9SAndroid Build Coastguard Worker self.assertTrue(v.requires_grad) 11866*da0073e9SAndroid Build Coastguard Worker return v 11867*da0073e9SAndroid Build Coastguard Worker 11868*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, [r]) 11869*da0073e9SAndroid Build Coastguard Worker gradgradcheck(fn, [r]) 11870*da0073e9SAndroid Build Coastguard Worker 11871*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS as double types are not supported 11872*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_python(self, device): 11873*da0073e9SAndroid Build Coastguard Worker # in-place modifications of Python-autograd created view 11874*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, 4, dtype=torch.double, device=device, requires_grad=True) 11875*da0073e9SAndroid Build Coastguard Worker b = torch.randn(2, 2, dtype=torch.double, device=device, requires_grad=True) 11876*da0073e9SAndroid Build Coastguard Worker 11877*da0073e9SAndroid Build Coastguard Worker class PyAdd(torch.autograd.Function): 11878*da0073e9SAndroid Build Coastguard Worker @staticmethod 11879*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x, y): 11880*da0073e9SAndroid Build Coastguard Worker ctx.mark_dirty(x) 11881*da0073e9SAndroid Build Coastguard Worker x.add_(y) 11882*da0073e9SAndroid Build Coastguard Worker return x 11883*da0073e9SAndroid Build Coastguard Worker 11884*da0073e9SAndroid Build Coastguard Worker @staticmethod 11885*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 11886*da0073e9SAndroid Build Coastguard Worker return grad, grad 11887*da0073e9SAndroid Build Coastguard Worker 11888*da0073e9SAndroid Build Coastguard Worker def func(root, b): 11889*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11890*da0073e9SAndroid Build Coastguard Worker PyAdd.apply(x.narrow(1, 2, 2).narrow(0, 1, 2), b) 11891*da0073e9SAndroid Build Coastguard Worker PyAdd.apply(x.narrow(1, 0, 2).narrow(0, 1, 2), b) 11892*da0073e9SAndroid Build Coastguard Worker return x 11893*da0073e9SAndroid Build Coastguard Worker 11894*da0073e9SAndroid Build Coastguard Worker gradcheck(func, [a, b], raise_exception=True) 11895*da0073e9SAndroid Build Coastguard Worker go = torch.randn( 11896*da0073e9SAndroid Build 
Coastguard Worker a.size(), dtype=torch.double, device=device, requires_grad=True 11897*da0073e9SAndroid Build Coastguard Worker ) 11898*da0073e9SAndroid Build Coastguard Worker gradgradcheck(func, (a, b), (go,)) 11899*da0073e9SAndroid Build Coastguard Worker 11900*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_non_contig(self, device): 11901*da0073e9SAndroid Build Coastguard Worker root = torch.ones(2, 3, 2, device=device).select(2, 1).t().requires_grad_(True) 11902*da0073e9SAndroid Build Coastguard Worker x = root.clone() 11903*da0073e9SAndroid Build Coastguard Worker v1 = x.narrow(0, 0, 1) 11904*da0073e9SAndroid Build Coastguard Worker v2 = v1.narrow(1, 1, 1) 11905*da0073e9SAndroid Build Coastguard Worker v2.mul_(2) 11906*da0073e9SAndroid Build Coastguard Worker x.sum().backward() 11907*da0073e9SAndroid Build Coastguard Worker self.assertEqual(root.grad.tolist(), [[1, 2], [1, 1], [1, 1]]) 11908*da0073e9SAndroid Build Coastguard Worker 11909*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_multi_output_unsafe(self, device): 11910*da0073e9SAndroid Build Coastguard Worker for f in [ 11911*da0073e9SAndroid Build Coastguard Worker lambda t: t.unsafe_split(1), 11912*da0073e9SAndroid Build Coastguard Worker lambda t: t.unsafe_split_with_sizes((1, 1, 1)), 11913*da0073e9SAndroid Build Coastguard Worker lambda t: t.unsafe_chunk(3), 11914*da0073e9SAndroid Build Coastguard Worker ]: 11915*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 3, device=device, requires_grad=True) 11916*da0073e9SAndroid Build Coastguard Worker b = a + a 11917*da0073e9SAndroid Build Coastguard Worker s1, s2, s3 = f(b) 11918*da0073e9SAndroid Build Coastguard Worker s1.mul_(s2) 11919*da0073e9SAndroid Build Coastguard Worker s1.sum().backward() 11920*da0073e9SAndroid Build Coastguard Worker 11921*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_multi_output_safe(self, device): 11922*da0073e9SAndroid Build Coastguard Worker for f in [ 11923*da0073e9SAndroid Build Coastguard Worker lambda t: t.split(1), 11924*da0073e9SAndroid Build Coastguard Worker lambda t: t.split_with_sizes((1, 1, 1)), 11925*da0073e9SAndroid Build Coastguard Worker lambda t: t.chunk(3), 11926*da0073e9SAndroid Build Coastguard Worker ]: 11927*da0073e9SAndroid Build Coastguard Worker a = torch.randn(3, 3, device=device, requires_grad=True) 11928*da0073e9SAndroid Build Coastguard Worker b = a + a 11929*da0073e9SAndroid Build Coastguard Worker s1, s2, s3 = f(b) 11930*da0073e9SAndroid Build Coastguard Worker error_msg = ( 11931*da0073e9SAndroid Build Coastguard Worker "This view is the output of a function that returns multiple views." 
11932*da0073e9SAndroid Build Coastguard Worker ) 11933*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, error_msg): 11934*da0073e9SAndroid Build Coastguard Worker s1.mul_(s2) 11935*da0073e9SAndroid Build Coastguard Worker 11936*da0073e9SAndroid Build Coastguard Worker def test_inplace_on_view_undefined_grad_output(self, device): 11937*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.0], requires_grad=True) 11938*da0073e9SAndroid Build Coastguard Worker c = a.clone() 11939*da0073e9SAndroid Build Coastguard Worker v = c[:] 11940*da0073e9SAndroid Build Coastguard Worker b = torch.tensor(1.0, requires_grad=True) 11941*da0073e9SAndroid Build Coastguard Worker 11942*da0073e9SAndroid Build Coastguard Worker class InplaceFunc(torch.autograd.Function): 11943*da0073e9SAndroid Build Coastguard Worker @staticmethod 11944*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x, other): 11945*da0073e9SAndroid Build Coastguard Worker ctx.mark_dirty(x) 11946*da0073e9SAndroid Build Coastguard Worker return x.mul_(2) 11947*da0073e9SAndroid Build Coastguard Worker 11948*da0073e9SAndroid Build Coastguard Worker @staticmethod 11949*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad): 11950*da0073e9SAndroid Build Coastguard Worker return grad * 2, None 11951*da0073e9SAndroid Build Coastguard Worker 11952*da0073e9SAndroid Build Coastguard Worker out = InplaceFunc.apply(v, b) 11953*da0073e9SAndroid Build Coastguard Worker out.backward() 11954*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(b.grad) 11955*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad.item(), 2) 11956*da0073e9SAndroid Build Coastguard Worker 11957*da0073e9SAndroid Build Coastguard Worker @skipIfMps # the test doesn't work on MPS as double types are not supported 11958*da0073e9SAndroid Build Coastguard Worker def test_mv_grad_stride_0(self, device): 11959*da0073e9SAndroid Build Coastguard Worker # Reference: https://github.com/pytorch/pytorch/issues/38315 11960*da0073e9SAndroid Build Coastguard Worker mat = torch.randn(2, 2, dtype=torch.double, device=device) 11961*da0073e9SAndroid Build Coastguard Worker vec = torch.randn(1, dtype=torch.double, device=device).requires_grad_(True) 11962*da0073e9SAndroid Build Coastguard Worker 11963*da0073e9SAndroid Build Coastguard Worker def fn(vec): 11964*da0073e9SAndroid Build Coastguard Worker # Expand inside the function to make sure the input to 11965*da0073e9SAndroid Build Coastguard Worker # gradcheck does not have overlapping memory 11966*da0073e9SAndroid Build Coastguard Worker vec = vec.expand(2) 11967*da0073e9SAndroid Build Coastguard Worker return (mat @ vec).sum() 11968*da0073e9SAndroid Build Coastguard Worker 11969*da0073e9SAndroid Build Coastguard Worker gradcheck(fn, (vec)) 11970*da0073e9SAndroid Build Coastguard Worker gradgradcheck(fn, (vec)) 11971*da0073e9SAndroid Build Coastguard Worker 11972*da0073e9SAndroid Build Coastguard Worker @onlyCUDA 11973*da0073e9SAndroid Build Coastguard Worker def test_gradcheck_input_output_different_device(self, device): 11974*da0073e9SAndroid Build Coastguard Worker x = torch.ones((1,), dtype=torch.double, device="cuda", requires_grad=True) 11975*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x.to("cpu"), (x,)) 11976*da0073e9SAndroid Build Coastguard Worker 11977*da0073e9SAndroid Build Coastguard Worker x = torch.ones((1,), dtype=torch.double, device="cpu", requires_grad=True) 11978*da0073e9SAndroid Build Coastguard Worker gradcheck(lambda x: x.to("cuda"), 
(x,)) 11979*da0073e9SAndroid Build Coastguard Worker 11980*da0073e9SAndroid Build Coastguard Worker def test_strided_leaf_grad_layout(self, device): 11981*da0073e9SAndroid Build Coastguard Worker # (1) If leaf is non-overlapping and dense, grad's layout should match its leaf. 11982*da0073e9SAndroid Build Coastguard Worker for fmt_a in (torch.contiguous_format, torch.channels_last): 11983*da0073e9SAndroid Build Coastguard Worker for fmt_b in (torch.contiguous_format, torch.channels_last): 11984*da0073e9SAndroid Build Coastguard Worker a = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_a) 11985*da0073e9SAndroid Build Coastguard Worker b = torch.rand((2, 3, 4, 5), device=device).to(memory_format=fmt_b) 11986*da0073e9SAndroid Build Coastguard Worker a.requires_grad_() 11987*da0073e9SAndroid Build Coastguard Worker b.requires_grad_() 11988*da0073e9SAndroid Build Coastguard Worker # checks (1) for broadcasted gradients 11989*da0073e9SAndroid Build Coastguard Worker a.sum().backward() 11990*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad.stride(), a.stride()) 11991*da0073e9SAndroid Build Coastguard Worker b.sum().backward() 11992*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad.stride(), b.stride()) 11993*da0073e9SAndroid Build Coastguard Worker # checks (1) for non-broadcasted gradients 11994*da0073e9SAndroid Build Coastguard Worker a.grad = None 11995*da0073e9SAndroid Build Coastguard Worker b.grad = None 11996*da0073e9SAndroid Build Coastguard Worker (a * b).sum().backward() 11997*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.grad.stride(), a.stride()) 11998*da0073e9SAndroid Build Coastguard Worker self.assertEqual(b.grad.stride(), b.stride()) 11999*da0073e9SAndroid Build Coastguard Worker 12000*da0073e9SAndroid Build Coastguard Worker # (2) If leaf isn't dense, checks that grads are rowmajor contiguous. 
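# For reference, rule (1) above can be observed directly on a single leaf
# (illustrative snippet, not one of the checks in this test): a channels_last
# leaf receives a channels_last gradient.
#     leaf = torch.rand(2, 3, 4, 5).to(memory_format=torch.channels_last)
#     leaf.requires_grad_()
#     leaf.sum().backward()
#     leaf.grad.stride() == leaf.stride()  # True; both are (60, 1, 15, 3)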
12001*da0073e9SAndroid Build Coastguard Worker c = torch.empty_strided((2, 2), (4, 2), device=device).copy_( 12002*da0073e9SAndroid Build Coastguard Worker torch.rand((2, 2), device=device) 12003*da0073e9SAndroid Build Coastguard Worker ) 12004*da0073e9SAndroid Build Coastguard Worker c.requires_grad_() 12005*da0073e9SAndroid Build Coastguard Worker d = torch.rand((2, 2), device=device) 12006*da0073e9SAndroid Build Coastguard Worker # checks (2) for broadcasted gradients 12007*da0073e9SAndroid Build Coastguard Worker c.sum().backward() 12008*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.grad.stride(), (2, 1)) 12009*da0073e9SAndroid Build Coastguard Worker # checks (2) for non-broadcasted gradients 12010*da0073e9SAndroid Build Coastguard Worker c.grad = None 12011*da0073e9SAndroid Build Coastguard Worker (c * d).sum().backward() 12012*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.grad.stride(), (2, 1)) 12013*da0073e9SAndroid Build Coastguard Worker 12014*da0073e9SAndroid Build Coastguard Worker @skipIfMps 12015*da0073e9SAndroid Build Coastguard Worker def test_copy_r_to_c(self, device): 12016*da0073e9SAndroid Build Coastguard Worker out_c = torch.empty(3, 2, dtype=torch.cdouble, device=device) 12017*da0073e9SAndroid Build Coastguard Worker inp_r = torch.randn(3, 2, dtype=torch.double, device=device, requires_grad=True) 12018*da0073e9SAndroid Build Coastguard Worker 12019*da0073e9SAndroid Build Coastguard Worker def do_test(): 12020*da0073e9SAndroid Build Coastguard Worker out_c.copy_(inp_r) 12021*da0073e9SAndroid Build Coastguard Worker out_c_inter = out_c.sum() 12022*da0073e9SAndroid Build Coastguard Worker out_c_inter.abs().backward() 12023*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 12024*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 12025*da0073e9SAndroid Build Coastguard Worker inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_c_inter).real 12026*da0073e9SAndroid Build Coastguard Worker ) 12027*da0073e9SAndroid Build Coastguard Worker 12028*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(do_test) 12029*da0073e9SAndroid Build Coastguard Worker 12030*da0073e9SAndroid Build Coastguard Worker def test_to_r_to_c(self, device): 12031*da0073e9SAndroid Build Coastguard Worker def do_test(): 12032*da0073e9SAndroid Build Coastguard Worker inp_r = torch.randn( 12033*da0073e9SAndroid Build Coastguard Worker 3, 2, dtype=torch.double, device=device, requires_grad=True 12034*da0073e9SAndroid Build Coastguard Worker ) 12035*da0073e9SAndroid Build Coastguard Worker out = inp_r.to(torch.complex128) 12036*da0073e9SAndroid Build Coastguard Worker out_inter = out.sum() 12037*da0073e9SAndroid Build Coastguard Worker out_inter.abs().backward() 12038*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 12039*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 12040*da0073e9SAndroid Build Coastguard Worker inp_r.grad, torch.ones_like(inp_r) * torch.sgn(out_inter).real 12041*da0073e9SAndroid Build Coastguard Worker ) 12042*da0073e9SAndroid Build Coastguard Worker 12043*da0073e9SAndroid Build Coastguard Worker self.assertNotWarn(do_test) 12044*da0073e9SAndroid Build Coastguard Worker 12045*da0073e9SAndroid Build Coastguard Worker def test_non_differentiable_ops(self, device): 12046*da0073e9SAndroid Build Coastguard Worker # Just make sure the op doesn't raise an error 12047*da0073e9SAndroid Build Coastguard Worker # and resulting tensor has requires_grad=False. 
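# The same applies to other ops whose outputs are boolean/integer valued
# (an additional illustration, not asserted below), e.g.
#     (torch.randn(3, requires_grad=True) > 0).requires_grad  # False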
        x = torch.tensor([[1, 2], [3, 4.0]], requires_grad=True, device=device)
        out = torch.isin(x, torch.tensor([2, 3], device=device))
        self.assertFalse(out.requires_grad)

        x = torch.randn(3, 3, requires_grad=True)
        out = torch.signbit(x)
        self.assertFalse(out.requires_grad)

    def test_warning_in_backward(self, device):
        # Test that warnings raised during backward are always propagated as Python warnings (gh-50209)
        # NOTE: For device=cuda, the warning gets propagated from a worker thread
        a = torch.zeros((), device=device, requires_grad=True)
        b = torch._C._nn._test_warn_in_autograd(a)

        with self.assertWarnsRegex(UserWarning, "Warn from backward"):
            b.backward()

    def test_complex_scalar_backward(self, device):
        a = torch.zeros(1, device=device, requires_grad=True)
        b = a * 0.5j

        msg = "grad can be implicitly created only for real scalar outputs"
        with self.assertRaisesRegex(RuntimeError, msg):
            b.backward()

        with self.assertRaisesRegex(RuntimeError, msg):
            torch.autograd.grad(b, a)

    def test_pow_real_negative_base_complex_exponent(self, device):
        # OpInfo doesn't naturally support input of mixed types, hence this test here.
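        # Mathematically, pow(base, exponent) with a negative real base and a
        # complex exponent is exp(exponent * log(base)) using the complex
        # (principal-branch) log, so the output and the gradient w.r.t. the
        # exponent are complex; gradcheck below verifies that gradient numerically.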
        base = -torch.ones(2, device=device, dtype=torch.double)
        exponent = torch.randn(
            2, device=device, dtype=torch.cdouble, requires_grad=True
        )

        def fn(exponent):
            return torch.pow(base, exponent)

        torch.autograd.gradcheck(fn, (exponent,))

        def fn(exponent):
            return torch.pow(-1, exponent)

        torch.autograd.gradcheck(fn, (exponent,))

    def test_resize_version_bump(self, device):
        x = torch.rand((1,), device=device)
        y = torch.randn((3,), device=device)
        x.resize_((1, 2))
        self.assertEqual(x._version, 1)
        x.resize_as_(y)
        self.assertEqual(x._version, 2)

        # In the following cases, `resize_` is a no-op,
        # so the version counter is not bumped.
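        # x already has y's shape (3,) at this point, so resizing it to (3,)
        # again (directly or via resize_as_) does not touch the data and the
        # version counter stays at 2.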
12103*da0073e9SAndroid Build Coastguard Worker x.resize_((3,)) 12104*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x._version, 2) 12105*da0073e9SAndroid Build Coastguard Worker 12106*da0073e9SAndroid Build Coastguard Worker x.resize_as_(y) 12107*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x._version, 2) 12108*da0073e9SAndroid Build Coastguard Worker 12109*da0073e9SAndroid Build Coastguard Worker 12110*da0073e9SAndroid Build Coastguard Workerclass TestAllowMutationOnSaved(TestCase): 12111*da0073e9SAndroid Build Coastguard Worker def assertClonedLenEqual(self, ctx, n): 12112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(ctx.cloned.items())), n) 12113*da0073e9SAndroid Build Coastguard Worker 12114*da0073e9SAndroid Build Coastguard Worker def assertTIDMapLenEqual(self, ctx, n): 12115*da0073e9SAndroid Build Coastguard Worker self.assertEqual(len(list(ctx.tid_to_weakhandle.items())), n) 12116*da0073e9SAndroid Build Coastguard Worker 12117*da0073e9SAndroid Build Coastguard Worker def test_basic(self): 12118*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12119*da0073e9SAndroid Build Coastguard Worker 12120*da0073e9SAndroid Build Coastguard Worker def fn(a): 12121*da0073e9SAndroid Build Coastguard Worker b = a.clone() 12122*da0073e9SAndroid Build Coastguard Worker out = (b**2).sum() 12123*da0073e9SAndroid Build Coastguard Worker b.sin_() 12124*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 12125*da0073e9SAndroid Build Coastguard Worker return a.grad 12126*da0073e9SAndroid Build Coastguard Worker 12127*da0073e9SAndroid Build Coastguard Worker msg = ( 12128*da0073e9SAndroid Build Coastguard Worker "variables needed for gradient computation has been modified by an inplace" 12129*da0073e9SAndroid Build Coastguard Worker ) 12130*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 12131*da0073e9SAndroid Build Coastguard Worker fn(a) 12132*da0073e9SAndroid Build Coastguard Worker 12133*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12134*da0073e9SAndroid Build Coastguard Worker da = fn(a) 12135*da0073e9SAndroid Build Coastguard Worker 12136*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(a * 2, da)) 12137*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12138*da0073e9SAndroid Build Coastguard Worker 12139*da0073e9SAndroid Build Coastguard Worker def test_views(self): 12140*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12141*da0073e9SAndroid Build Coastguard Worker 12142*da0073e9SAndroid Build Coastguard Worker def fn(a): 12143*da0073e9SAndroid Build Coastguard Worker b = a.clone() 12144*da0073e9SAndroid Build Coastguard Worker c = b.view_as(b) 12145*da0073e9SAndroid Build Coastguard Worker out = (b**2).sum() # How does this work? 
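            # Roughly: pow's backward saves its input `b`; under
            # allow_mutation_on_saved_tensors, mutating `b` through the view
            # `c` makes the context clone the saved value first, so backward
            # still sees the original `b` and a.grad comes out as 2 * a.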
12146*da0073e9SAndroid Build Coastguard Worker c.sin_() 12147*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 12148*da0073e9SAndroid Build Coastguard Worker return a.grad 12149*da0073e9SAndroid Build Coastguard Worker 12150*da0073e9SAndroid Build Coastguard Worker msg = ( 12151*da0073e9SAndroid Build Coastguard Worker "variables needed for gradient computation has been modified by an inplace" 12152*da0073e9SAndroid Build Coastguard Worker ) 12153*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 12154*da0073e9SAndroid Build Coastguard Worker fn(a) 12155*da0073e9SAndroid Build Coastguard Worker 12156*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12157*da0073e9SAndroid Build Coastguard Worker da = fn(a) 12158*da0073e9SAndroid Build Coastguard Worker 12159*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12160*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(a * 2, da)) 12161*da0073e9SAndroid Build Coastguard Worker 12162*da0073e9SAndroid Build Coastguard Worker def test_save_base_and_modify_view(self): 12163*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12164*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12165*da0073e9SAndroid Build Coastguard Worker b = a.clone() 12166*da0073e9SAndroid Build Coastguard Worker c = b[:1] 12167*da0073e9SAndroid Build Coastguard Worker out = b**2 12168*da0073e9SAndroid Build Coastguard Worker # modify the view 12169*da0073e9SAndroid Build Coastguard Worker c *= 10 12170*da0073e9SAndroid Build Coastguard Worker # self.assertClonedLenEqual(ctx, 1) 12171*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 12172*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12173*da0073e9SAndroid Build Coastguard Worker 12174*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12175*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(a * 2, a.grad)) 12176*da0073e9SAndroid Build Coastguard Worker 12177*da0073e9SAndroid Build Coastguard Worker def test_save_view_modify_base(self): 12178*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12179*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12180*da0073e9SAndroid Build Coastguard Worker b = a.clone() 12181*da0073e9SAndroid Build Coastguard Worker c = b[:] 12182*da0073e9SAndroid Build Coastguard Worker out = (c**2).sum() 12183*da0073e9SAndroid Build Coastguard Worker b *= 2 12184*da0073e9SAndroid Build Coastguard Worker out.backward() 12185*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(a * 2, a.grad)) 12186*da0073e9SAndroid Build Coastguard Worker 12187*da0073e9SAndroid Build Coastguard Worker def test_double_backward(self): 12188*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12189*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12190*da0073e9SAndroid Build Coastguard Worker b = a.clone() 12191*da0073e9SAndroid Build Coastguard Worker out = (b**2).sum() 12192*da0073e9SAndroid Build Coastguard Worker b.sin_() 12193*da0073e9SAndroid Build Coastguard Worker torch.autograd.grad(out, a, create_graph=True) 12194*da0073e9SAndroid Build Coastguard Worker (da,) = torch.autograd.grad(out, a, 
create_graph=True) 12195*da0073e9SAndroid Build Coastguard Worker (d2a,) = torch.autograd.grad(da.sum(), a) 12196*da0073e9SAndroid Build Coastguard Worker 12197*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.allclose(torch.ones_like(a) * 2, d2a)) 12198*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12199*da0073e9SAndroid Build Coastguard Worker 12200*da0073e9SAndroid Build Coastguard Worker def test_saved_but_not_anymore(self): 12201*da0073e9SAndroid Build Coastguard Worker # Make sure we don't clone if the tensor was once saved, but 12202*da0073e9SAndroid Build Coastguard Worker # by the time we do in-place, it is no longer saved 12203*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12204*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 3, requires_grad=True).clone() 12205*da0073e9SAndroid Build Coastguard Worker out = (a**2).sum() 12206*da0073e9SAndroid Build Coastguard Worker self.assertTIDMapLenEqual(ctx, 1) 12207*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12208*da0073e9SAndroid Build Coastguard Worker out.backward() 12209*da0073e9SAndroid Build Coastguard Worker a.sin_() 12210*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12211*da0073e9SAndroid Build Coastguard Worker out = (a**2).sum() 12212*da0073e9SAndroid Build Coastguard Worker a.sin_() 12213*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 1) 12214*da0073e9SAndroid Build Coastguard Worker del out 12215*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12216*da0073e9SAndroid Build Coastguard Worker 12217*da0073e9SAndroid Build Coastguard Worker def test_saved_same_tensor_many_times(self): 12218*da0073e9SAndroid Build Coastguard Worker # We should only clone once 12219*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12220*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 3, requires_grad=True).clone() 12221*da0073e9SAndroid Build Coastguard Worker b = a**2 12222*da0073e9SAndroid Build Coastguard Worker c = a**2 12223*da0073e9SAndroid Build Coastguard Worker a.sin_() 12224*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 1) 12225*da0073e9SAndroid Build Coastguard Worker del b, c 12226*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12227*da0073e9SAndroid Build Coastguard Worker 12228*da0073e9SAndroid Build Coastguard Worker def test_saved_same_tensor_different_versions(self): 12229*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12230*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 3, requires_grad=True).clone() 12231*da0073e9SAndroid Build Coastguard Worker b = a**2 12232*da0073e9SAndroid Build Coastguard Worker a.sin_() 12233*da0073e9SAndroid Build Coastguard Worker c = a**2 12234*da0073e9SAndroid Build Coastguard Worker a.sin_() 12235*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 2) 12236*da0073e9SAndroid Build Coastguard Worker del b 12237*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 1) 12238*da0073e9SAndroid Build Coastguard Worker del c 12239*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12240*da0073e9SAndroid Build Coastguard Worker 12241*da0073e9SAndroid Build Coastguard Worker def test_with_math_views(self): 12242*da0073e9SAndroid 
Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12243*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1 + 1j], requires_grad=True).clone() 12244*da0073e9SAndroid Build Coastguard Worker b = a.conj() 12245*da0073e9SAndroid Build Coastguard Worker out = (b**2).sum() 12246*da0073e9SAndroid Build Coastguard Worker a.sin_() 12247*da0073e9SAndroid Build Coastguard Worker out.abs().backward() 12248*da0073e9SAndroid Build Coastguard Worker 12249*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1 + 1j], requires_grad=True).clone() 12250*da0073e9SAndroid Build Coastguard Worker b = a.conj() 12251*da0073e9SAndroid Build Coastguard Worker out = (b**2).sum() 12252*da0073e9SAndroid Build Coastguard Worker # in this case, it is no longer a view it seems 12253*da0073e9SAndroid Build Coastguard Worker b.sin_() 12254*da0073e9SAndroid Build Coastguard Worker out.abs().backward() 12255*da0073e9SAndroid Build Coastguard Worker 12256*da0073e9SAndroid Build Coastguard Worker def test_with_out_variant(self): 12257*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12258*da0073e9SAndroid Build Coastguard Worker a = torch.tensor([1.0], requires_grad=True) 12259*da0073e9SAndroid Build Coastguard Worker b = torch.tensor([1.0]) 12260*da0073e9SAndroid Build Coastguard Worker c = torch.tensor([2.0]) 12261*da0073e9SAndroid Build Coastguard Worker out = a * b 12262*da0073e9SAndroid Build Coastguard Worker self.assertTIDMapLenEqual(ctx, 1) 12263*da0073e9SAndroid Build Coastguard Worker torch.sin(c, out=b) 12264*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 1) 12265*da0073e9SAndroid Build Coastguard Worker out.backward() 12266*da0073e9SAndroid Build Coastguard Worker self.assertClonedLenEqual(ctx, 0) 12267*da0073e9SAndroid Build Coastguard Worker 12268*da0073e9SAndroid Build Coastguard Worker def test_backward_out_of_context(self): 12269*da0073e9SAndroid Build Coastguard Worker # Out of context 12270*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12271*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12272*da0073e9SAndroid Build Coastguard Worker out = (a**2).sum() 12273*da0073e9SAndroid Build Coastguard Worker 12274*da0073e9SAndroid Build Coastguard Worker msg = "Trying to backward outside of the 'allow_mutation_on_saved_tensors' context" 12275*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, msg): 12276*da0073e9SAndroid Build Coastguard Worker out.backward() 12277*da0073e9SAndroid Build Coastguard Worker 12278*da0073e9SAndroid Build Coastguard Worker # Different context 12279*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12280*da0073e9SAndroid Build Coastguard Worker a = torch.rand(2, 3, requires_grad=True) 12281*da0073e9SAndroid Build Coastguard Worker out = (a**2).sum() 12282*da0073e9SAndroid Build Coastguard Worker 12283*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12284*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(AssertionError, msg): 12285*da0073e9SAndroid Build Coastguard Worker out.backward() 12286*da0073e9SAndroid Build Coastguard Worker 12287*da0073e9SAndroid Build Coastguard Worker def test_disallow_nesting(self): 12288*da0073e9SAndroid Build Coastguard Worker with 
torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12289*da0073e9SAndroid Build Coastguard Worker msg = "allow_mutation_on_saved_tensors contexts cannot be nested" 12290*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, msg): 12291*da0073e9SAndroid Build Coastguard Worker with torch.autograd.graph.allow_mutation_on_saved_tensors() as ctx: 12292*da0073e9SAndroid Build Coastguard Worker pass 12293*da0073e9SAndroid Build Coastguard Worker 12294*da0073e9SAndroid Build Coastguard Worker 12295*da0073e9SAndroid Build Coastguard Workerclass TestAutogradInferenceMode(TestCase): 12296*da0073e9SAndroid Build Coastguard Worker def _is_inference_tensor(self, tensor): 12297*da0073e9SAndroid Build Coastguard Worker try: 12298*da0073e9SAndroid Build Coastguard Worker err_msg = "Inference tensors do not track version counter" 12299*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12300*da0073e9SAndroid Build Coastguard Worker tensor._version 12301*da0073e9SAndroid Build Coastguard Worker return True 12302*da0073e9SAndroid Build Coastguard Worker except AssertionError as e: 12303*da0073e9SAndroid Build Coastguard Worker return False 12304*da0073e9SAndroid Build Coastguard Worker 12305*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_context_manager(self): 12306*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference_mode_enabled()) 12307*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12308*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference_mode_enabled()) 12309*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(False): 12310*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference_mode_enabled()) 12311*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference_mode_enabled()) 12312*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference_mode_enabled()) 12313*da0073e9SAndroid Build Coastguard Worker 12314*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_decorator(self): 12315*da0073e9SAndroid Build Coastguard Worker def func(x): 12316*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.is_inference_mode_enabled(), mode) 12317*da0073e9SAndroid Build Coastguard Worker return x * x 12318*da0073e9SAndroid Build Coastguard Worker 12319*da0073e9SAndroid Build Coastguard Worker for mode, use_kwarg in product((True, False, None), (True, False)): 12320*da0073e9SAndroid Build Coastguard Worker if mode is None: 12321*da0073e9SAndroid Build Coastguard Worker if use_kwarg: 12322*da0073e9SAndroid Build Coastguard Worker decorated = torch.inference_mode(mode=func) 12323*da0073e9SAndroid Build Coastguard Worker else: 12324*da0073e9SAndroid Build Coastguard Worker decorated = torch.inference_mode(func) 12325*da0073e9SAndroid Build Coastguard Worker mode = True 12326*da0073e9SAndroid Build Coastguard Worker else: 12327*da0073e9SAndroid Build Coastguard Worker if use_kwarg: 12328*da0073e9SAndroid Build Coastguard Worker decorated = torch.inference_mode(mode=mode)(func) 12329*da0073e9SAndroid Build Coastguard Worker else: 12330*da0073e9SAndroid Build Coastguard Worker decorated = torch.inference_mode(mode)(func) 12331*da0073e9SAndroid Build Coastguard Worker 12332*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12333*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3, requires_grad=requires_grad) 
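                # When the decorator enables inference mode, the output of the
                # decorated function is an inference tensor and never requires
                # grad; with mode=False, normal autograd behavior is preserved.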
12334*da0073e9SAndroid Build Coastguard Worker d = decorated(c) 12335*da0073e9SAndroid Build Coastguard Worker self.assertTrue(not mode or torch.is_inference(d)) 12336*da0073e9SAndroid Build Coastguard Worker self.assertEqual(d.requires_grad, requires_grad and not mode) 12337*da0073e9SAndroid Build Coastguard Worker 12338*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_tensor_creation(self): 12339*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12340*da0073e9SAndroid Build Coastguard Worker # new tensors created through constructors are inference tensors 12341*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3) 12342*da0073e9SAndroid Build Coastguard Worker self.assertFalse(c.requires_grad) 12343*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference(c)) 12344*da0073e9SAndroid Build Coastguard Worker 12345*da0073e9SAndroid Build Coastguard Worker # requires_grad doesn't change inference tensor behavior in InferenceMode 12346*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(1, 2, 3, requires_grad=True) 12347*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tmp.requires_grad) 12348*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference(tmp)) 12349*da0073e9SAndroid Build Coastguard Worker 12350*da0073e9SAndroid Build Coastguard Worker tmp = torch.ones(1, 2, 3).requires_grad_(False) 12351*da0073e9SAndroid Build Coastguard Worker self.assertFalse(tmp.requires_grad) 12352*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference(tmp)) 12353*da0073e9SAndroid Build Coastguard Worker 12354*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_existing_autograd_session(self): 12355*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=True) 12356*da0073e9SAndroid Build Coastguard Worker a = s.clone() 12357*da0073e9SAndroid Build Coastguard Worker 12358*da0073e9SAndroid Build Coastguard Worker # `a` gets saved outside of inference mode 12359*da0073e9SAndroid Build Coastguard Worker out = a * a 12360*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12361*da0073e9SAndroid Build Coastguard Worker a.add_(2) 12362*da0073e9SAndroid Build Coastguard Worker 12363*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(a)) 12364*da0073e9SAndroid Build Coastguard Worker # tensors created outside of inference mode aren't 12365*da0073e9SAndroid Build Coastguard Worker # inference tensors, so they will still have their 12366*da0073e9SAndroid Build Coastguard Worker # version counters tracked 12367*da0073e9SAndroid Build Coastguard Worker err_msg = ( 12368*da0073e9SAndroid Build Coastguard Worker "one of the variables needed for gradient computation has been " 12369*da0073e9SAndroid Build Coastguard Worker "modified by an inplace operation" 12370*da0073e9SAndroid Build Coastguard Worker ) 12371*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12372*da0073e9SAndroid Build Coastguard Worker out.backward(torch.ones_like(out)) 12373*da0073e9SAndroid Build Coastguard Worker 12374*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_inf_tensor_in_inf_mode_functional_op(self): 12375*da0073e9SAndroid Build Coastguard Worker def functional_op(x): 12376*da0073e9SAndroid Build Coastguard Worker return x * x 12377*da0073e9SAndroid Build Coastguard Worker 12378*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12379*da0073e9SAndroid Build Coastguard 
            for requires_grad in (True, False):
                c = torch.ones(1, 2, 3, requires_grad=requires_grad)

                # performing a non-view operation produces an inference tensor
                # that does not require grad
                func_out = functional_op(c)
                self.assertTrue(torch.is_inference(func_out))
                self.assertFalse(func_out.requires_grad)

    def test_inference_mode_inf_tensor_in_inf_mode_inplace_op(self):
        @torch.inference_mode()
        def run_test(fn):
            for requires_grad in (True, False):
                c = torch.ones(1, 2, 3, requires_grad=requires_grad)

                # after performing an inplace operation, the tensor is still
                # an inference tensor
                fn(c)
                self.assertTrue(torch.is_inference(c))
                self.assertEqual(c.requires_grad, requires_grad)

        run_test(lambda x: x.add_(2))
        run_test(lambda x: x.transpose_(0, 1))

        # inplace ops with a manual kernel for the ADInplaceOrView key in VariableTypeManual.cpp
        run_test(lambda x: x.resize_(1, 2))
        run_test(lambda x: x.resize_as_(torch.ones(1, 2)))
        run_test(lambda x: x.copy_(torch.ones(1, 2, 3)))

    def test_inference_mode_inf_tensor_in_inf_mode_view_op(self):
        with torch.inference_mode():
            for requires_grad in (True, False):
                c = torch.ones(1, 2, 3, requires_grad=requires_grad)

                # performing a view operation produces an inference tensor
                # that does not require grad
                view_out = c.view(-1)
                self.assertTrue(torch.is_inference(view_out))
                self.assertFalse(view_out.requires_grad)

    def test_inference_mode_inf_tensor_in_normal_mode_functional_op(self):
        def functional_op(x):
            return x * x

        for requires_grad in (True, False):
12424*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12425*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12426*da0073e9SAndroid Build Coastguard Worker 12427*da0073e9SAndroid Build Coastguard Worker func_out = functional_op(c) 12428*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(func_out)) 12429*da0073e9SAndroid Build Coastguard Worker self.assertFalse(func_out.requires_grad) 12430*da0073e9SAndroid Build Coastguard Worker self.assertTrue(func_out.is_leaf) 12431*da0073e9SAndroid Build Coastguard Worker 12432*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_inf_tensor_in_normal_mode_inplace_op(self): 12433*da0073e9SAndroid Build Coastguard Worker def run_test(fn): 12434*da0073e9SAndroid Build Coastguard Worker for requires_grad in (False, True): 12435*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12436*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12437*da0073e9SAndroid Build Coastguard Worker 12438*da0073e9SAndroid Build Coastguard Worker if requires_grad: 12439*da0073e9SAndroid Build Coastguard Worker # leaf variable that requires grad is being used in an inplace 12440*da0073e9SAndroid Build Coastguard Worker # operation when requires_grad=True 12441*da0073e9SAndroid Build Coastguard Worker pass 12442*da0073e9SAndroid Build Coastguard Worker else: 12443*da0073e9SAndroid Build Coastguard Worker err_msg = "Inplace update to inference tensor outside InferenceMode" 12444*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12445*da0073e9SAndroid Build Coastguard Worker fn(c) 12446*da0073e9SAndroid Build Coastguard Worker 12447*da0073e9SAndroid Build Coastguard Worker run_test(lambda x: x.add_(2)) 12448*da0073e9SAndroid Build Coastguard Worker run_test(lambda x: x.transpose_(0, 1)) 12449*da0073e9SAndroid Build Coastguard Worker 12450*da0073e9SAndroid Build Coastguard Worker def test_inference_mode_inf_tensor_in_normal_mode_view_op(self): 12451*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12452*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12453*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12454*da0073e9SAndroid Build Coastguard Worker 12455*da0073e9SAndroid Build Coastguard Worker out = c.view(-1) 12456*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference(out)) 12457*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out.requires_grad) 12458*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out._is_view()) 12459*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.is_leaf) 12460*da0073e9SAndroid Build Coastguard Worker 12461*da0073e9SAndroid Build Coastguard Worker def test_normal_tensor_inplace_output_in_inference_mode(self): 12462*da0073e9SAndroid Build Coastguard Worker def run_test(fn): 12463*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12464*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12465*da0073e9SAndroid Build Coastguard Worker a = s.clone() 12466*da0073e9SAndroid Build Coastguard Worker 12467*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12468*da0073e9SAndroid Build Coastguard Worker fn(a) 12469*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(a)) 12470*da0073e9SAndroid Build 
Coastguard Worker self.assertEqual(a.requires_grad, requires_grad) 12471*da0073e9SAndroid Build Coastguard Worker 12472*da0073e9SAndroid Build Coastguard Worker # inplace -> inplace 12473*da0073e9SAndroid Build Coastguard Worker fn(a) 12474*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(a)) 12475*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.requires_grad, requires_grad) 12476*da0073e9SAndroid Build Coastguard Worker 12477*da0073e9SAndroid Build Coastguard Worker # inplace -> inplace -> view 12478*da0073e9SAndroid Build Coastguard Worker view_out = a.view(-1) 12479*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(view_out)) 12480*da0073e9SAndroid Build Coastguard Worker self.assertEqual(view_out.requires_grad, requires_grad) 12481*da0073e9SAndroid Build Coastguard Worker 12482*da0073e9SAndroid Build Coastguard Worker run_test(lambda x: x.add_(2)) 12483*da0073e9SAndroid Build Coastguard Worker run_test(lambda x: x.transpose_(0, 1)) 12484*da0073e9SAndroid Build Coastguard Worker 12485*da0073e9SAndroid Build Coastguard Worker def test_normal_tensor_inplace_output_in_normal_mode(self): 12486*da0073e9SAndroid Build Coastguard Worker def run_test(fn): 12487*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12488*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12489*da0073e9SAndroid Build Coastguard Worker a = s.clone() 12490*da0073e9SAndroid Build Coastguard Worker 12491*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12492*da0073e9SAndroid Build Coastguard Worker fn(a) 12493*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(a)) 12494*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.requires_grad, requires_grad) 12495*da0073e9SAndroid Build Coastguard Worker 12496*da0073e9SAndroid Build Coastguard Worker fn(a) 12497*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(a)) 12498*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.requires_grad, requires_grad) 12499*da0073e9SAndroid Build Coastguard Worker 12500*da0073e9SAndroid Build Coastguard Worker # inplace -> inplace 12501*da0073e9SAndroid Build Coastguard Worker fn(a) 12502*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(a)) 12503*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a.requires_grad, requires_grad) 12504*da0073e9SAndroid Build Coastguard Worker 12505*da0073e9SAndroid Build Coastguard Worker # inplace -> inplace -> view 12506*da0073e9SAndroid Build Coastguard Worker view_out = a.view(-1) 12507*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(view_out)) 12508*da0073e9SAndroid Build Coastguard Worker self.assertEqual(view_out.requires_grad, requires_grad) 12509*da0073e9SAndroid Build Coastguard Worker run_test(lambda x: x.add_(2)) 12510*da0073e9SAndroid Build Coastguard Worker run_test(lambda x: x.transpose_(0, 1)) 12511*da0073e9SAndroid Build Coastguard Worker 12512*da0073e9SAndroid Build Coastguard Worker def test_normal_tensor_view_output_in_inference_mode(self): 12513*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12514*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12515*da0073e9SAndroid Build Coastguard Worker a = s.clone() 12516*da0073e9SAndroid Build Coastguard Worker 12517*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 
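                # A view taken inside inference mode of a tensor created
                # outside of it is not an inference tensor; it keeps normal
                # view metadata, as the assertions below check.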
12518*da0073e9SAndroid Build Coastguard Worker out = a.view(-1) 12519*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(out)) 12520*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.requires_grad, requires_grad) 12521*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out._is_view()) 12522*da0073e9SAndroid Build Coastguard Worker 12523*da0073e9SAndroid Build Coastguard Worker # view -> view 12524*da0073e9SAndroid Build Coastguard Worker tmp = out.view(-1) 12525*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(tmp)) 12526*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tmp.requires_grad, requires_grad) 12527*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tmp._is_view()) 12528*da0073e9SAndroid Build Coastguard Worker self.assertTrue(tmp.is_leaf) 12529*da0073e9SAndroid Build Coastguard Worker 12530*da0073e9SAndroid Build Coastguard Worker # view -> view -> inplace 12531*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.is_inference_mode_enabled()) 12532*da0073e9SAndroid Build Coastguard Worker tmp.add_(2) 12533*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(tmp)) 12534*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tmp.requires_grad, requires_grad) 12535*da0073e9SAndroid Build Coastguard Worker # Accessing is_leaf in python tries to update grad_fn and raises: 12536*da0073e9SAndroid Build Coastguard Worker # A view was created in inference mode and its base or 12537*da0073e9SAndroid Build Coastguard Worker # another view of its base has been modified inplace in normal mode 12538*da0073e9SAndroid Build Coastguard Worker # tmp.is_leaf 12539*da0073e9SAndroid Build Coastguard Worker self.assertEqual(a._version, tmp._version) 12540*da0073e9SAndroid Build Coastguard Worker 12541*da0073e9SAndroid Build Coastguard Worker def test_normal_tensor_view_output_in_normal_mode(self): 12542*da0073e9SAndroid Build Coastguard Worker def functional_op(x): 12543*da0073e9SAndroid Build Coastguard Worker return x * x 12544*da0073e9SAndroid Build Coastguard Worker 12545*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12546*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12547*da0073e9SAndroid Build Coastguard Worker a = s.clone() 12548*da0073e9SAndroid Build Coastguard Worker 12549*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12550*da0073e9SAndroid Build Coastguard Worker out = a.view(-1) 12551*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(out)) 12552*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.requires_grad, requires_grad) 12553*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out._is_view()) 12554*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.is_leaf) 12555*da0073e9SAndroid Build Coastguard Worker 12556*da0073e9SAndroid Build Coastguard Worker tmp = functional_op(out) 12557*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(tmp)) 12558*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tmp.requires_grad, requires_grad) 12559*da0073e9SAndroid Build Coastguard Worker 12560*da0073e9SAndroid Build Coastguard Worker if requires_grad: 12561*da0073e9SAndroid Build Coastguard Worker err_msg = ( 12562*da0073e9SAndroid Build Coastguard Worker "A view was created in inference mode and is being modified inplace" 12563*da0073e9SAndroid Build Coastguard Worker ) 
12564*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12565*da0073e9SAndroid Build Coastguard Worker out.add_(2) 12566*da0073e9SAndroid Build Coastguard Worker else: 12567*da0073e9SAndroid Build Coastguard Worker out.add_(2) 12568*da0073e9SAndroid Build Coastguard Worker 12569*da0073e9SAndroid Build Coastguard Worker tmp = out.view(2, 3) 12570*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(tmp)) 12571*da0073e9SAndroid Build Coastguard Worker self.assertEqual(tmp.requires_grad, requires_grad) 12572*da0073e9SAndroid Build Coastguard Worker 12573*da0073e9SAndroid Build Coastguard Worker def test_mix_inference_and_normal_tensor_functional_op(self): 12574*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12575*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12576*da0073e9SAndroid Build Coastguard Worker 12577*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12578*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3, requires_grad=requires_grad) 12579*da0073e9SAndroid Build Coastguard Worker 12580*da0073e9SAndroid Build Coastguard Worker # add is safe since it doesn't save any variable for backward 12581*da0073e9SAndroid Build Coastguard Worker out = c.add(s) 12582*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.is_inference(out)) 12583*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.requires_grad, requires_grad) 12584*da0073e9SAndroid Build Coastguard Worker if requires_grad: 12585*da0073e9SAndroid Build Coastguard Worker # leaf inference tensor with requires_grad=True can still have gradient 12586*da0073e9SAndroid Build Coastguard Worker out.backward(torch.ones_like(out)) 12587*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c.grad, torch.ones_like(c)) 12588*da0073e9SAndroid Build Coastguard Worker 12589*da0073e9SAndroid Build Coastguard Worker if requires_grad: 12590*da0073e9SAndroid Build Coastguard Worker err_msg = "Inference tensors cannot be saved for backward" 12591*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12592*da0073e9SAndroid Build Coastguard Worker c * s 12593*da0073e9SAndroid Build Coastguard Worker 12594*da0073e9SAndroid Build Coastguard Worker # TODO: Test this with an autograd.Function when it works 12595*da0073e9SAndroid Build Coastguard Worker # stack stopped capturing a TensorList input 12596*da0073e9SAndroid Build Coastguard Worker # # inference tensor in TensorList input 12597*da0073e9SAndroid Build Coastguard Worker # inputs = [s, c] 12598*da0073e9SAndroid Build Coastguard Worker # with self.assertRaisesRegex(RuntimeError, err_msg): 12599*da0073e9SAndroid Build Coastguard Worker # torch.stack(inputs) 12600*da0073e9SAndroid Build Coastguard Worker 12601*da0073e9SAndroid Build Coastguard Worker def test_mix_inference_and_normal_tensor_inplace_op(self): 12602*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12603*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12604*da0073e9SAndroid Build Coastguard Worker a = s.clone() 12605*da0073e9SAndroid Build Coastguard Worker 12606*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12607*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3) 12608*da0073e9SAndroid Build Coastguard Worker 12609*da0073e9SAndroid Build Coastguard Worker 
self.assertTrue(torch.is_inference(c)) 12610*da0073e9SAndroid Build Coastguard Worker if requires_grad: 12611*da0073e9SAndroid Build Coastguard Worker err_msg = "Inference tensors cannot be saved for backward" 12612*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12613*da0073e9SAndroid Build Coastguard Worker a.mul_(c) 12614*da0073e9SAndroid Build Coastguard Worker 12615*da0073e9SAndroid Build Coastguard Worker # inference tensor in TensorList input 12616*da0073e9SAndroid Build Coastguard Worker err_msg = ( 12617*da0073e9SAndroid Build Coastguard Worker "out=... arguments don't support automatic differentiation, " 12618*da0073e9SAndroid Build Coastguard Worker "but one of the arguments requires grad" 12619*da0073e9SAndroid Build Coastguard Worker ) 12620*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12621*da0073e9SAndroid Build Coastguard Worker torch.mul(s, s, out=c) 12622*da0073e9SAndroid Build Coastguard Worker else: 12623*da0073e9SAndroid Build Coastguard Worker a.mul_(c) 12624*da0073e9SAndroid Build Coastguard Worker err_msg = "Inplace update to inference tensor outside InferenceMode is not allowed" 12625*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 12626*da0073e9SAndroid Build Coastguard Worker torch.mul(s, s, out=c) 12627*da0073e9SAndroid Build Coastguard Worker 12628*da0073e9SAndroid Build Coastguard Worker def test_mix_inference_and_normal_tensor_view_op(self): 12629*da0073e9SAndroid Build Coastguard Worker for requires_grad in (True, False): 12630*da0073e9SAndroid Build Coastguard Worker s = torch.ones(1, 2, 3, requires_grad=requires_grad) 12631*da0073e9SAndroid Build Coastguard Worker 12632*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 12633*da0073e9SAndroid Build Coastguard Worker c = torch.ones(1, 2, 3) 12634*da0073e9SAndroid Build Coastguard Worker 12635*da0073e9SAndroid Build Coastguard Worker # view_as is a composite op which calls view with only one 12636*da0073e9SAndroid Build Coastguard Worker # tensor argument. 
            # So there are no mixed inference and normal tensor inputs for view ops.
            tmp1 = c.view_as(s)
            self.assertTrue(torch.is_inference(tmp1))
            self.assertFalse(tmp1.requires_grad)

            # this is fine since it's equivalent to s.view(c.sizes()), which
            # isn't a mixed input scenario
            tmp2 = s.view_as(c)
            self.assertFalse(torch.is_inference(tmp2))
            self.assertEqual(tmp2.requires_grad, requires_grad)

    def test_inference_mode_handle_direct_view_on_rebase(self):
        def run_test(fn):
            for requires_grad in (True, False):
                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
                a = s.clone()

                with torch.inference_mode():
                    view_out = a.view_as(a)

                if requires_grad:
                    err_msg = "A view was created in inference mode and is being modified inplace"
                    with self.assertRaisesRegex(RuntimeError, err_msg):
                        fn(view_out)
                else:
                    fn(view_out)

        run_test(lambda x: x.add_(2))
        run_test(lambda x: x.transpose_(0, 1))

    def test_inference_mode_handle_indirect_view_on_rebase(self):
        def run_test(fn):
            for requires_grad in (True, False):
                s = torch.ones(1, 2, 3, requires_grad=requires_grad)
                a = s.clone()

                with torch.inference_mode():
                    view_out = a.view(-1)

                fn(a)
                if requires_grad:
                    err_msg = "A view was created in inference mode and its base or another view "
                    with self.assertRaisesRegex(RuntimeError, err_msg):
                        view_out.grad_fn
                else:
                    view_out.grad_fn
        run_test(lambda x: x.add_(2))
        run_test(lambda x: x.transpose_(0, 1))


class TestMultithreadAutograd(TestCase):
    def _run_py_multithread_fn(
        self, fn, args=(), num_threads=10, kwargs=None, pass_idx=False
    ):
        class PropagatingThread(threading.Thread):
            """Helper class to propagate exception from child
            thread to main thread on join.

            Reference: https://stackoverflow.com/a/31614591/5602957
            """

            def run(self):
                self.exception = None
                try:
                    self.ret = super().run()
                except Exception as e:
                    self.exception = e

            def join(self, timeout=None):
                super().join(timeout)
                if self.exception:
                    raise self.exception from self.exception
                return self.ret

        threads = []
        for idx in range(num_threads):
            p = PropagatingThread(target=fn, args=((idx, *args) if pass_idx else args))
            p.start()
            threads.append(p)

        for p in threads:
            p.join()

    def test_multithreaded_exception_propagation(self):
        # Test whether exceptions raised in a child thread
        # are propagated to the main thread.
        def fn():
            self.assertTrue(False)

        with self.assertRaises(AssertionError):
            self._run_py_multithread_fn(fn)

    def test_simple_backward(self):
        # simple multithreaded backward that creates threads at the beginning of training,
        # while everything else (inputs, operations, etc.) happens separately per thread
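        # y = 0.5 * (x + 3) * (x + 4), so dy/dx = 0.5 * ((x + 4) + (x + 3)) = x + 3.5,
        # which is the per-thread gradient the assertions below check (and it
        # accumulates to num_threads * (x + 3.5) when threads share one leaf).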
        def train_fn():
            x = torch.ones(5, 5, requires_grad=True)
            y = (x + 3) * (x + 4) * 0.5
            y.sum().backward()
            self.assertEqual(x.grad, x + 3.5)

        self._run_py_multithread_fn(train_fn)

    def test_simple_backward_same_input(self):
        # simple multithreaded backward with only shared inputs (this is common
        # for things like Hogwild multithreaded training with multiple CPU threads)
        def train_fn_backward(x):
            y = (x + 3) * (x + 4) * 0.5
            y.sum().backward()

        x = torch.ones(5, 5, requires_grad=True)
        self._run_py_multithread_fn(train_fn_backward, (x,))
        # Since we are calling backward from multiple threads
        # and all threads share the same input, when we do backward
        # concurrently, different backwards will all accumulate into
        # the same .grad for each input, and the gradients should
        # be equal to num_threads * gradient
        self.assertEqual(x.grad, 10 * (x + 3.5))

        def train_fn_grad(x):
            y = (x + 3) * (x + 4) * 0.5
            grads = torch.autograd.grad(y.sum(), x)
            self.assertEqual(len(grads), 1)
            self.assertEqual(grads[0], x + 3.5)

        # since we use the functional grad() API, gradients are not
        # accumulated into the same .grad, and each call should return the same value
        self._run_py_multithread_fn(train_fn_grad, (x,))

    def test_multi_grad_all_hooks(self):
        # Multihooks should behave independently per execution of backward
        # Test that the hook fires the number of times we run backward,
        # even if those executions occur concurrently on different threads
        t1 = torch.rand(2, requires_grad=True)
        t2 = torch.rand(2, requires_grad=True)
        t3 = torch.rand(2, requires_grad=True)
        t4 = torch.rand(2, requires_grad=True)
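        # In the default "all" mode the hook fires once per backward call, with
        # one grad entry per hooked tensor; t1 and t4 do not participate in
        # `out`, so their entries are None and `res` should end up as
        # [False, True, True, False], as asserted below.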
        res = None
        count = [0]
        hook_lock = threading.Lock()

        def hook(grads):
            nonlocal res
            with hook_lock:
                count[0] += 1
                grad_is_none = [g is not None for g in grads]
                if res is None:
                    res = grad_is_none
                else:
                    self.assertEqual(res, grad_is_none)

        torch.autograd.graph.register_multi_grad_hook((t1, t2, t3, t4), hook)

        out = (t2 * t3).sum()

        def backward_retain_graph(out, t2, t3):
            out.backward(inputs=(t2, t3), retain_graph=True)

        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)

        self.assertEqual(count[0], 5)
        self.assertEqual(res, [False, True, True, False])

        # Leave one hook partially applied
        res = None
        count = [0]
        err_count = [0]
        bw_count = [0]
        bw_count_lock = threading.Lock()
        err_count_lock = threading.Lock()

        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                with bw_count_lock:
                    bw_count[0] += 1
                    if bw_count[0] == 1:
                        raise RuntimeError("error message")
                    else:
                        return gO

        out = (Func.apply(t2) * t3).sum()

        def backward_retain_graph(out, t2, t3):
            try:
                out.backward(inputs=(t2, t3), retain_graph=True)
            except RuntimeError:
                with err_count_lock:
                    err_count[0] += 1

        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)

        self.assertEqual(count[0], 4)
        self.assertEqual(err_count[0], 1)
        self.assertEqual(res, [False, True, True, False])

    def test_multi_grad_any_hooks(self):
        # Multihooks should behave independently per execution of backward
        # Test that the hook fired the number of times we ran backward
        # even if those executions occur concurrently on different threads
        t1 = torch.rand(2, requires_grad=True)
        t2 = torch.rand(2, requires_grad=True)
        t3 = torch.rand(2, requires_grad=True)
        t4 = torch.rand(2, requires_grad=True)

        res = None
        count = [0]
        hook_lock = threading.Lock()

        def hook(grad):
            nonlocal res
            with hook_lock:
                count[0] += 1
                if res is None:
                    res = "foo"
                else:
                    self.assertEqual(res, "foo")

        torch.autograd.graph.register_multi_grad_hook(
            (t1, t2, t3, t4), hook, mode="any"
        )

        out = (t2 * t3).sum()

        def backward_retain_graph(out, t2, t3):
            out.backward(inputs=(t2, t3), retain_graph=True)

        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)
        self.assertEqual(count[0], 5)
        self.assertEqual(res, "foo")

        # Raise an error in one thread's backward
        res = None
        count = [0]
        err_count = [0]
        bw_count = [0]
        bw_count_lock = threading.Lock()
        err_count_lock = threading.Lock()

        class Func(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                with bw_count_lock:
                    bw_count[0] += 1
                    if bw_count[0] == 1:
                        raise RuntimeError("error message")
                    else:
                        return gO

        out = (Func.apply(t2) * t3).sum()

        def backward_retain_graph(out, t2, t3):
            try:
                out.backward(inputs=(t2, t3), retain_graph=True)
            except RuntimeError:
                with err_count_lock:
                    err_count[0] += 1

        self._run_py_multithread_fn(backward_retain_graph, (out, t2, t3), num_threads=5)

        # Expect all 5 threads to increment count since the hook runs before
        # the custom backward
        self.assertEqual(count[0], 5)
        self.assertEqual(err_count[0], 1)
        self.assertEqual(res, "foo")

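    # Editor's note: the helper below is an illustrative, hedged sketch added for
    # clarity; it is not part of the original test suite and is never invoked. It
    # contrasts the default "all" mode of register_multi_grad_hook, whose hook
    # receives one (possibly None) gradient per tracked tensor, with mode="any",
    # whose hook receives a single gradient as soon as the first one is ready.
    @staticmethod
    def _example_multi_grad_hook_modes():
        a = torch.rand(2, requires_grad=True)
        b = torch.rand(2, requires_grad=True)

        all_results = []
        # "all" mode (default): called once per backward with a tuple of grads
        torch.autograd.graph.register_multi_grad_hook(
            (a, b), lambda grads: all_results.append([g is not None for g in grads])
        )

        any_results = []
        # "any" mode: called once per backward with the first gradient computed
        torch.autograd.graph.register_multi_grad_hook(
            (a, b), lambda grad: any_results.append(grad), mode="any"
        )

        (a * b).sum().backward()
        # all_results == [[True, True]]; any_results holds exactly one tensor
        return all_results, any_results
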
    def test_dataparallel_saved_tensors_hooks(self):
        def pack(x):
            warnings.warn("pack")
            return x

        _self = self

        class Model(torch.nn.Module):
            def forward(self, x):
                with warnings.catch_warnings(record=True) as w:
                    y = x * x
                    if torch.cuda.device_count() >= 2:
                        # DataParallel calls forward from different threads
                        # without propagating TLS, so hooks should not be called here
                        _self.assertEqual(len(w), 0)
                    else:
                        # DataParallel only uses one thread
                        # so hooks should be called here
                        _self.assertGreater(len(w), 0)

        x = torch.ones(5, 5, requires_grad=True)
        model = torch.nn.DataParallel(Model())

        with torch.autograd.graph.saved_tensors_hooks(pack, lambda x: x):
            model(x)
            with warnings.catch_warnings(record=True) as w:
                y = x * x
                # hooks should be called here
                _self.assertGreater(len(w), 0)

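    # Editor's note: a minimal, hedged sketch (not part of the original tests, never
    # invoked) of the saved_tensors_hooks mechanism exercised above: pack runs when
    # autograd saves a tensor during the forward, unpack runs when backward retrieves it.
    @staticmethod
    def _example_saved_tensors_hooks():
        events = []

        def pack(x):
            events.append("pack")
            return x

        def unpack(x):
            events.append("unpack")
            return x

        a = torch.ones(2, 2, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            y = a * a  # mul saves its inputs for backward, so pack fires here
        y.sum().backward()  # the saved tensors are fetched, so unpack fires here
        return events  # expected to contain "pack" entries followed by "unpack" entries
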
    def test_python_thread_in_middle(self):
        # User might write a network that starts on one CPU thread, then runs its second half
        # concurrently with other threads (either via python threading or fork/join calls),
        # then calls backward()/grad() on BOTH threads, like a Y pattern from input at the
        # bottom to output at the top. This way part of the GraphTask is shared across
        # different threads, and we need to ensure that the user specifies retain_graph=True;
        # otherwise we error out with the correct error message.

        # Case 1: multiple backward with python threads, retain_graph=False
        # should throw error in some threads with no retain_graph.
        success_vs_raises = [0, 0]

        def train_fn_no_retain_graph(x):
            y = x + x**2
            try:
                y.sum().backward()
                success_vs_raises[0] += 1
            except RuntimeError as error:
                success_vs_raises[1] += 1
                self.assertRegex(str(error), "Specify retain_graph=True")

        x_no_retain = torch.ones(5, 5, requires_grad=True)
        y_no_retain = x_no_retain + x_no_retain**2
        self._run_py_multithread_fn(
            train_fn_no_retain_graph, (y_no_retain,), num_threads=5
        )
        # At least one thread succeeds in this case; all other threads should raise
        # the error that recommends the user specify retain_graph=True.
        self.assertTrue(success_vs_raises[0] >= 1)

        # multiple backward with python threads, no error with retain_graph=True
        def train_fn_retain_graph(x):
            y = x + x**2
            y.sum().backward(retain_graph=True)

        x_retain = torch.ones(5, 5, requires_grad=True)
        y_retain = x_retain + x_retain**2
        self._run_py_multithread_fn(train_fn_retain_graph, (y_retain,), num_threads=5)
        # the accumulated gradient should equal num_threads * the single-run gradient
        self.assertEqual(
            x_retain.grad,
            5 * (4 * x_retain**3 + 6 * (x_retain**2) + 4 * x_retain + 1),
        )

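    # Editor's note: hedged illustration (not part of the original tests, never invoked)
    # of the retain_graph requirement the test above relies on: once the shared part of
    # a graph has been backwarded through with retain_graph=False, any further backward
    # over those same nodes raises the "Specify retain_graph=True" RuntimeError, whether
    # the second call comes from another thread or, as here, from the same one.
    @staticmethod
    def _example_shared_graph_needs_retain_graph():
        x = torch.ones(3, requires_grad=True)
        y = x + x**2  # this part of the graph is shared by both backward calls
        y.sum().backward()  # frees the shared graph's saved tensors
        try:
            y.sum().backward()  # walks the freed part of the graph again
        except RuntimeError as e:
            return "Specify retain_graph=True" in str(e)
        return False
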
    def test_fork_join_in_middle(self):
        # multiple backward with jit threads (fork/join primitive)
        # similar to test_python_thread_in_middle, we test with retain_graph=False/True

        # Case 1: multiple grad() calls with jit threads, retain_graph=False
        # should throw error in some threads with no retain_graph.
        @torch.jit.script
        def train_fn_jit_no_retain(middle, orig_x):
            y = middle + middle**2
            return torch.autograd.grad([y.sum()], [orig_x])

        @torch.jit.script
        def train_fn_fork_join_calls_no_retain(x):
            y_no_retain = (x + 3) * (x + 4) * 0.5

            fut = torch.jit._fork(train_fn_jit_no_retain, y_no_retain, x)
            grad_hat = train_fn_jit_no_retain(y_no_retain, x)
            grad = torch.jit._wait(fut)
            return grad, grad_hat

        try:
            train_fn_fork_join_calls_no_retain(torch.randn(5, 5, requires_grad=True))
        except RuntimeError as error:
            self.assertRegex(str(error), "Specify retain_graph=True")

        # Case 2: no error with retain_graph=True
        @torch.jit.script
        def train_fn_jit_retain(middle, orig_x):
            y = middle + middle**2
            return torch.autograd.grad([y.sum()], [orig_x], retain_graph=True)

        @torch.jit.script
        def train_fn_fork_join_calls_retain(x):
            y_retain = (x + 3) * (x + 4) * 0.5
            fut1 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
            fut2 = torch.jit._fork(train_fn_jit_retain, y_retain, x)
            grad = train_fn_jit_retain(y_retain, x)
            grad1 = torch.jit._wait(fut1)
            grad2 = torch.jit._wait(fut2)
            return grad, grad1, grad2

        grad, grad1, grad2 = train_fn_fork_join_calls_retain(
            torch.randn(5, 5, requires_grad=True)
        )
        self.assertEqual(grad, grad1)
        self.assertEqual(grad, grad2)

    def test_preserve_backtrace(self):
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return input

            @staticmethod
            def backward(ctx, *grad):
                raise ValueError("something")

        t = torch.rand(10, requires_grad=True)
        try:
            Foo.apply(t).sum().backward()
        except Exception:
            import traceback

            tb = sys.exc_info()[2]
            tb_str = "\n".join(traceback.format_tb(tb))
            self.assertTrue('raise ValueError("something")' in tb_str)

    # TODO(@anjali411): add an OpInfo based test for torch.cat
    # Issue: https://github.com/pytorch/pytorch/issues/51627
    #        https://github.com/pytorch/pytorch/issues/75852
    def test_cat_stack_r_to_c(self):
        inp_c = torch.rand(3, 2, dtype=torch.cdouble, requires_grad=True)
        inp_r = torch.randn(3, 2, dtype=torch.double, requires_grad=True)

        def fn(x1, x2):
            return torch.cat((x1, x2), dim=-1)

        def fn2(x1, x2):
            return torch.stack((x1, x2), dim=-1)

        torch.autograd.gradcheck(fn, [inp_r, inp_c], check_forward_ad=True)
        torch.autograd.gradcheck(fn, [inp_c, inp_r], check_forward_ad=True)

        torch.autograd.gradcheck(fn2, [inp_r, inp_c], check_forward_ad=True)
        torch.autograd.gradcheck(fn2, [inp_c, inp_r], check_forward_ad=True)

    def test_set_multithreading_enabled_as_context_manager_and_function(self):
        # Test as a context manager
        with torch.autograd.set_multithreading_enabled(False):
            self.assertFalse(torch.autograd.is_multithreading_enabled())
        self.assertTrue(torch.autograd.is_multithreading_enabled())

        with torch.autograd.set_multithreading_enabled(True):
            self.assertTrue(torch.autograd.is_multithreading_enabled())
        self.assertTrue(torch.autograd.is_multithreading_enabled())

        with torch.autograd.set_multithreading_enabled(False):
            torch.autograd.set_multithreading_enabled(True)
            self.assertTrue(torch.autograd.is_multithreading_enabled())
        self.assertTrue(torch.autograd.is_multithreading_enabled())

        torch.autograd.set_multithreading_enabled(False)
        self.assertFalse(torch.autograd.is_multithreading_enabled())

        torch.autograd.set_multithreading_enabled(True)
        self.assertTrue(torch.autograd.is_multithreading_enabled())

    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
    def test_custom_function_propagates_errors_from_device_thread(self):
        class MyFunc(Function):
            @staticmethod
            def forward(ctx, x):
                return x

            @staticmethod
            def backward(ctx, gO):
                raise RuntimeError("blah")
                return gO

        t = torch.tensor([1.0, 2.0], requires_grad=True, device=torch.device("cuda"))
        out = MyFunc.apply(t).sum()

        with self.assertRaisesRegex(RuntimeError, "blah"):
            out.backward()

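# Editor's note: hedged, illustrative sketch (not part of the original tests, never
# invoked) of the API exercised in the multithreading-toggle test above:
# torch.autograd.set_multithreading_enabled works both as a plain setter and as a
# context manager, and torch.autograd.is_multithreading_enabled reports the setting.
def _example_toggle_autograd_multithreading():
    assert torch.autograd.is_multithreading_enabled()  # enabled by default
    with torch.autograd.set_multithreading_enabled(False):
        # backward() calls in this scope run single-threaded across devices
        assert not torch.autograd.is_multithreading_enabled()
    assert torch.autograd.is_multithreading_enabled()  # restored on exit

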
class TestNestedCheckpoint(TestCase):
    @staticmethod
    def grad(fn):
        def wrapper(x):
            with torch.enable_grad():
                out = fn(x)
                (grad_input,) = torch.autograd.grad(out, inputs=(x,), create_graph=True)
            return grad_input

        return wrapper

    @staticmethod
    def sum(fn):
        def wrapped(x):
            return fn(x).sum()

        return wrapped

    @staticmethod
    def checkpoint(fn):
        def wrapped(*args, **kwargs):
            return torch.utils.checkpoint.checkpoint(
                fn, *args, use_reentrant=False, **kwargs
            )

        return wrapped

    def get_tests(self, fn):
        grad, c = self.grad, self.checkpoint

        tests = (
            # function <> tuple of function arbitrarily wrapped in checkpoint in various ways
            (fn, (c(fn), c(c(fn)))),
            (grad(fn), (grad(c(fn)), grad(c(c(fn))))),
            (
                grad(grad(fn)),
                (grad(c(grad(fn))), c(grad(grad(c(fn)))), grad(c(grad(c(fn))))),
            ),
            (
                grad(grad(grad(fn))),
                (grad(c(grad(grad(c(fn))))), grad(c(grad(c(grad(c(fn))))))),
            ),
        )
        return tests

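    # Editor's note: hedged sketch (not part of the original tests, never invoked) of
    # what the wrappers above compose to: `checkpoint` here is this class's
    # non-reentrant torch.utils.checkpoint wrapper, and nesting it must leave gradients
    # unchanged; only where activations get recomputed differs.
    @staticmethod
    def _example_nested_checkpoint_equivalence():
        def f(x):
            return x.sin().exp().sin()

        x = torch.randn((), requires_grad=True)
        expected = torch.autograd.grad(f(x), x)[0]

        inner = TestNestedCheckpoint.checkpoint(f)       # checkpoint(f)
        nested = TestNestedCheckpoint.checkpoint(inner)  # checkpoint(checkpoint(f))
        actual = torch.autograd.grad(nested(x), x)[0]
        return torch.allclose(expected, actual)  # expected: True
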
    def check_graph_dies(self, fn):
        def iter_graph(roots):
            if not roots:
                return
            seen = set()
            q = collections.deque()
            for node in roots:
                if node is not None:
                    seen.add(node)
                    q.append(node)

            while q:
                node = q.popleft()
                for fn, _idx in node.next_functions:
                    if fn in seen or fn is None:
                        continue
                    seen.add(fn)
                    q.append(fn)

                yield node

        class Handle:
            # __weakref__ is kept in the slots so weakref.ref(handle) below still works
            __slots__ = ["node_name", "__weakref__"]

            def __init__(self, node_name):
                self.node_name = node_name

        def scope():
            a = torch.randn((), requires_grad=True)
            out = fn(a)
            refs = []
            for node in iter_graph([out.grad_fn]):
                handle = Handle(node.name())
                refs.append(weakref.ref(handle))
                node.metadata["blah"] = handle
            return refs

        refs = scope()
        node_names = [ref().node_name for ref in refs if ref() is not None]
        if len(node_names) > 0:
            print("Nodes still alive:", node_names)

        self.assertEqual(len(node_names), 0)

    @parametrize("early_stop", [True, False])
    def test_nested_checkpoint(self, early_stop):
        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
            x = torch.randn((), requires_grad=True)

            def f(x):
                out = x.sin().exp().sin()
                return out

            def g(x):
                a = x.sin().exp().sin()
                b = x.sin().exp().sin()
                (ga,) = torch.autograd.grad(a, x)
                (gb,) = torch.autograd.grad(b, x)
                return x.sin()

            for fn in (f, g):
                for expected_fn, actual_fns in self.get_tests(fn):
                    expected = expected_fn(x)

                    for actual_fn in actual_fns:
                        actual = actual_fn(x)
                        self.assertTrue(torch.allclose(expected, actual))
                        self.check_graph_dies(actual_fn)

    @parametrize("early_stop", [True, False])
    def test_nested_checkpoint_two_children(self, early_stop):
        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
            grad, sum, c = self.grad, self.sum, self.checkpoint

            def f(x):
                return x.sin().exp().sin()

            def g(x):
                return x.cos().sin().exp()

            def hc(x):
                return c(g)(c(f)(x))

            def h(x):
                return g(f(x))

            a = torch.randn(3, 3, requires_grad=True)
            expected = grad(sum(grad(sum(h))))(a)
            actual = grad(sum(grad(sum(c(hc)))))(a)
            self.assertTrue(torch.allclose(expected, actual))

            actual = grad(sum(c(grad(sum(c(hc))))))(a)
            self.assertTrue(torch.allclose(expected, actual))

            self.check_graph_dies(grad(c(hc)))
            self.check_graph_dies(grad(sum(grad(sum(c(hc))))))
            self.check_graph_dies(grad(sum(c(grad(sum(c(hc)))))))

    @parametrize("early_stop", [True, False])
    def test_nested_checkpoint_non_tensor_inputs_and_outputs(self, early_stop):
        def fn(k, a, b, f):
            return f(k * a * b.exp()), 1, "abcd"

        k = 3
        a = torch.tensor(2.0, requires_grad=True)
        b = torch.tensor(3.0, requires_grad=True)

        def f(x):
            return x.sin()

        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
            out, _unused1, _unused2 = checkpoint(fn, k, a, b, f, use_reentrant=False)
            actual_grads = torch.autograd.grad(out, (a, b))

            out, _unused1, _unused2 = fn(k, a, b, f)
            expected_grads = torch.autograd.grad(out, (a, b))
            for actual, expected in zip(actual_grads, expected_grads):
                self.assertTrue(torch.allclose(actual, expected))

    @parametrize("early_stop", [True, False])
    def test_nested_checkpoint_kwargs(self, early_stop):
        def fn(a, blah=None):
            out = a.sin().exp()
            if blah is not None:
                out = out * blah
            return out.sin().exp()

        a = torch.tensor(2.0, requires_grad=True)
        b = torch.tensor(3.0, requires_grad=True)

        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
            out = checkpoint(fn, a, blah=b, use_reentrant=False)
            actual_grads = torch.autograd.grad(out, (a, b))

            out = fn(a, blah=b)
            expected_grads = torch.autograd.grad(out, (a, b))
            for actual, expected in zip(actual_grads, expected_grads):
                self.assertTrue(torch.allclose(actual, expected))

    @parametrize("early_stop", [True, False])
    def test_nested_checkpoint_same_graph(self, early_stop):
        counter = [0]

        def hook(*_unused_args):
            counter[0] += 1

        def fn(a):
            return a.sin().cos().sin()

        a = torch.tensor(1.0, requires_grad=True)

        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
            out = checkpoint(fn, a, use_reentrant=False)
            # The hook is registered on the original graph
            out.grad_fn.next_functions[0][0].register_hook(hook)
            # And backward is performed on the original graph
            out.backward()

        self.assertEqual(counter[0], 1)

    @parametrize("early_stop", [True, False])
    def test_nested_checkpoint_reentrant_backwards(self, early_stop):
        def fn(a):
            x = a.sin().cos()
            out = x.sin()
            return x, out

        def hook(*_unused_args):
            # do backward again, but skip over the part of the graph where
            # the hook was registered
            x.backward(retain_graph=True)

        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(early_stop):
            x, out = checkpoint(fn, a, use_reentrant=False)
            out.grad_fn.register_hook(hook)
            out.backward(retain_graph=True)

    def test_nested_checkpoint_set_early_stop(self):
        counter = [0]

        def clone(x):
            counter[0] += 1
            return x.clone()

        def fn(x):
            # Since clone does not save anything, early stop halts recomputation
            # before reaching it, so clone is only re-run when early stop is disabled.
            return clone(x.sin().cos())

        # Early stopping is enabled by default
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Try using the context manager to set early stopping to False.
        # Expect early stopping to be disabled for all checkpoints run under
        # the context manager, even though the context manager is no longer active
        # when backward/recomputation is performed.
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn, a, use_reentrant=False)

        out.backward()
        self.assertEqual(counter[0], 2)

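    # Editor's note: hedged sketch (not part of the original tests, never invoked)
    # restating the semantics checked above: the set_checkpoint_early_stop value is
    # captured when the checkpointed forward runs, so it still governs the
    # recomputation triggered later by backward(), after the context manager exited.
    @staticmethod
    def _example_early_stop_captured_at_forward_time():
        calls = [0]

        def tail(x):
            calls[0] += 1  # reached during recompute only when early stop is disabled
            return x.clone()

        def fn(x):
            return tail(x.sin().cos())

        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn, a, use_reentrant=False)
        out.backward()  # recomputation happens here, outside the context manager
        return calls[0]  # expected: 2, since the captured False setting still applies
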
    def test_nested_checkpoint_set_early_stop_no_recompution_needed(self):
        # Case 1: We have one tensor saved and it's the input

        # We have two different counters here because in this case we actually
        # do call into x.sin() at the python level during recomputation whether
        # or not early stop is enabled. This is because the early stopping
        # only happens at the autograd level (preventing us from reaching the
        # backend).
        python_dispatch_counter = [0]
        counter = [0]

        class SinCounterMode(TorchDispatchMode):
            def __init__(self) -> None:
                self.count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                kwargs = {} if kwargs is None else kwargs
                if func is torch.ops.aten.sin.default:
                    self.count += 1
                return func(*args, **kwargs)

        def fn(x):
            counter[0] += 1
            return x.sin()

        # With early stopping (enabled by default)
        a = torch.tensor(1.0, requires_grad=True)
        with SinCounterMode() as python_dispatch_counter:  # noqa: F811
            out = checkpoint(fn, a, use_reentrant=False)
            out.backward()
        self.assertEqual(counter[0], 2)
        self.assertEqual(python_dispatch_counter.count, 1)

        # Without early stopping
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with SinCounterMode() as python_dispatch_counter:
            with torch.utils.checkpoint.set_checkpoint_early_stop(False):
                out = checkpoint(fn, a, use_reentrant=False)
            out.backward()
        self.assertEqual(counter[0], 2)
        self.assertEqual(python_dispatch_counter.count, 2)

        # Case 2: Forward saves no tensors

        # Since unpack isn't even called, counter is 1 whether or not early stop
        # is enabled!
        counter = [0]

        def fn2(x):
            counter[0] += 1
            return x.clone()

        # With early stopping (enabled by default)
        a = torch.tensor(1.0, requires_grad=True)
        out = checkpoint(fn2, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)

        # Without early stopping
        counter = [0]
        a = torch.tensor(1.0, requires_grad=True)
        with torch.utils.checkpoint.set_checkpoint_early_stop(False):
            out = checkpoint(fn2, a, use_reentrant=False)
        out.backward()
        self.assertEqual(counter[0], 1)


class TestSelectiveActivationCheckpoint(TestCase):
    @unittest.skipIf(not TEST_CUDA, "requires CUDA")
    def test_flops_and_mem(self):
        # From https://github.com/pytorch/pytorch/pull/126320
        def get_act_mem(f):
            out = f()
            out.backward()
            # Why do one forward and backward?
            start_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
            out = f()
            cur_mem = torch.cuda.memory_stats()["requested_bytes.all.current"]
            act_mem = (cur_mem - start_mem) / (1024 * 1024)
            out.backward()
            return act_mem

        def get_bw_flops(f):
            # Normalized so that a 512 square matmul returns 1
            f().backward()
            out = f()
            # NB: FlopCounterMode is pushed onto the mode stack before CachedMode, so
            # it will be able to observe whether an op is cached or not.
            with FlopCounterMode(display=False) as mode:
                out.backward()
            return mode.get_total_flops() / (512**3 * 2)

        x = torch.randn(512, 512, requires_grad=True, device="cuda")
        y = torch.randn(512, 512, requires_grad=True, device="cuda")

        def fn(x, y):
            return torch.mm(x.cos(), y).sin().sum()

        def fn_ac(x, y):
            return checkpoint(fn, x, y, use_reentrant=False)

        def fn_sac(x, y):
            context_fn = functools.partial(
                create_selective_checkpoint_contexts,
                [torch.ops.aten.mm.default],
            )
            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
            return out

        def policy_fn(ctx, op, *args, **kwargs):
            if op == torch.ops.aten.mm.default:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        def fn_sac2(x, y):
            context_fn = functools.partial(
                create_selective_checkpoint_contexts,
                policy_fn,
            )
            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
            return out

        def policy_fn_bool(ctx, op, *args, **kwargs):
            return op == torch.ops.aten.mm.default

        def fn_sac3(x, y):
            context_fn = functools.partial(
                create_selective_checkpoint_contexts,
                policy_fn_bool,
            )
            out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
            return out

        act_mem_noac = get_act_mem(lambda: fn(x, y))
        bw_flops_noac = get_bw_flops(lambda: fn(x, y))

        self.assertEqual(act_mem_noac, 2.0)
        self.assertEqual(bw_flops_noac, 2.0)

        act_mem_ac = get_act_mem(lambda: fn_ac(x, y))
        bw_flops_ac = get_bw_flops(lambda: fn_ac(x, y))

        self.assertEqual(act_mem_ac, 0.0)
        self.assertEqual(bw_flops_ac, 3.0)

        act_mem_sac = get_act_mem(lambda: fn_sac(x, y))
        bw_flops_sac = get_bw_flops(lambda: fn_sac(x, y))

        self.assertEqual(act_mem_sac, 1.0)
        self.assertEqual(bw_flops_sac, 2.0)

        act_mem_sac2 = get_act_mem(lambda: fn_sac2(x, y))
        bw_flops_sac2 = get_bw_flops(lambda: fn_sac2(x, y))

        self.assertEqual(act_mem_sac2, 1.0)
        self.assertEqual(bw_flops_sac2, 2.0)

        act_mem_sac3 = get_act_mem(lambda: fn_sac3(x, y))
        bw_flops_sac3 = get_bw_flops(lambda: fn_sac3(x, y))

        self.assertEqual(act_mem_sac3, 1.0)
        self.assertEqual(bw_flops_sac3, 2.0)

    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_output_already_has_autograd_meta(self):
        # View of tensor of non-differentiable dtype still has AutogradMeta
        def fn(x, y):
            return x.view(-1), y.sin().cos()

        x = torch.tensor([1, 2, 3], dtype=torch.int64)
        y = torch.randn(3, requires_grad=True)

        context_fn = functools.partial(
            create_selective_checkpoint_contexts,
            [torch.ops.aten.view.default],
        )
        out = checkpoint(fn, x, y, use_reentrant=False, context_fn=context_fn)
        out[1].sum().backward()

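    # Editor's note: hedged sketch (not part of the original tests, never invoked) of
    # the two equivalent ways to drive create_selective_checkpoint_contexts used above:
    # an explicit list of OpOverloads to save, or a policy function returning a
    # CheckpointPolicy. Ops marked MUST_SAVE keep their outputs and are not recomputed
    # during backward; everything else defaults to PREFER_RECOMPUTE.
    @staticmethod
    def _example_selective_checkpoint_policies():
        def fn(x, y):
            return torch.mm(x.cos(), y).sin().sum()

        # Variant 1: save every aten.mm result, recompute the rest
        save_mm_list = functools.partial(
            create_selective_checkpoint_contexts, [torch.ops.aten.mm.default]
        )

        # Variant 2: the same policy expressed as a function
        def policy(ctx, op, *args, **kwargs):
            if op == torch.ops.aten.mm.default:
                return CheckpointPolicy.MUST_SAVE
            return CheckpointPolicy.PREFER_RECOMPUTE

        save_mm_fn = functools.partial(create_selective_checkpoint_contexts, policy)

        x = torch.randn(8, 8, requires_grad=True)
        y = torch.randn(8, 8, requires_grad=True)
        out1 = checkpoint(fn, x, y, use_reentrant=False, context_fn=save_mm_list)
        out2 = checkpoint(fn, x, y, use_reentrant=False, context_fn=save_mm_fn)
        return torch.allclose(out1, out2)  # same forward value either way
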
13549*da0073e9SAndroid Build Coastguard Worker # Test that we ignore ops that grab metadata like torch.ops.aten.sym_size.default 13550*da0073e9SAndroid Build Coastguard Worker # Caching such metadata ops can be problematic when the following are satisfied: 13551*da0073e9SAndroid Build Coastguard Worker # 13552*da0073e9SAndroid Build Coastguard Worker # 1. size/strides are dispatched upon 13553*da0073e9SAndroid Build Coastguard Worker # 2. our policy saves sizes 13554*da0073e9SAndroid Build Coastguard Worker ta = torch.randn(6, 2) 13555*da0073e9SAndroid Build Coastguard Worker 13556*da0073e9SAndroid Build Coastguard Worker class CustomSizeDynamicShapesTensor(torch.Tensor): 13557*da0073e9SAndroid Build Coastguard Worker @staticmethod 13558*da0073e9SAndroid Build Coastguard Worker def __new__(cls, inner): 13559*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( 13560*da0073e9SAndroid Build Coastguard Worker # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. 13561*da0073e9SAndroid Build Coastguard Worker # Calling the overload that has kwargs causes us to go down the first overload path, 13562*da0073e9SAndroid Build Coastguard Worker # which will **always** specialize sizes. 13563*da0073e9SAndroid Build Coastguard Worker # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 13564*da0073e9SAndroid Build Coastguard Worker cls, 13565*da0073e9SAndroid Build Coastguard Worker inner.size(), 13566*da0073e9SAndroid Build Coastguard Worker inner.stride(), 13567*da0073e9SAndroid Build Coastguard Worker None, 13568*da0073e9SAndroid Build Coastguard Worker None, 13569*da0073e9SAndroid Build Coastguard Worker inner.dtype, 13570*da0073e9SAndroid Build Coastguard Worker inner.layout, 13571*da0073e9SAndroid Build Coastguard Worker inner.device, 13572*da0073e9SAndroid Build Coastguard Worker False, 13573*da0073e9SAndroid Build Coastguard Worker inner.requires_grad, 13574*da0073e9SAndroid Build Coastguard Worker "sizes", 13575*da0073e9SAndroid Build Coastguard Worker ) 13576*da0073e9SAndroid Build Coastguard Worker 13577*da0073e9SAndroid Build Coastguard Worker def __init__(self, inner): 13578*da0073e9SAndroid Build Coastguard Worker self.inner = inner 13579*da0073e9SAndroid Build Coastguard Worker 13580*da0073e9SAndroid Build Coastguard Worker @classmethod 13581*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 13582*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 13583*da0073e9SAndroid Build Coastguard Worker kwargs = {} 13584*da0073e9SAndroid Build Coastguard Worker args_inner = torch.utils._pytree.tree_map_only( 13585*da0073e9SAndroid Build Coastguard Worker cls, lambda x: x.inner, args 13586*da0073e9SAndroid Build Coastguard Worker ) 13587*da0073e9SAndroid Build Coastguard Worker out_inner = func(*args_inner, **kwargs) 13588*da0073e9SAndroid Build Coastguard Worker return torch.utils._pytree.tree_map_only( 13589*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda x: cls(x), out_inner 13590*da0073e9SAndroid Build Coastguard Worker ) 13591*da0073e9SAndroid Build Coastguard Worker 13592*da0073e9SAndroid Build Coastguard Worker def policy_fn(ctx, op, *args, **kwargs): 13593*da0073e9SAndroid Build Coastguard Worker if op is torch.ops.aten.sym_size.default: 13594*da0073e9SAndroid Build Coastguard Worker # Silently ignored! 
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        def fn(x):
            # We avoid the following case
            #
            # saved     : [4, 3], [],  [],  [4, 3], [4, 3], [4, 3], [12]
            # forward   : sum   , sum, mul, mul   , mul   , view  , view
            # recompute : sum   , sum, mul, view  , view
            #
            # Views save the shape of their input, so we expect the second
            # view to save 12, but because packing during the forward part of AC
            # also saves the shapes of the inputs for later metadata checks,
            # we would otherwise save the wrong shape during the recompute.
            view_out = (x * x.sum()).view(-1).view(4, 3)
            self.assertEqual(view_out.grad_fn._saved_self_sym_sizes, [12])
            return view_out.exp()

        x = torch.randn(4, 3, requires_grad=True)
        x_wrapper = CustomSizeDynamicShapesTensor(x)
        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        out = checkpoint(fn, x_wrapper, use_reentrant=False, context_fn=context_fn)
        out.sum().backward()

    def test_bad_inputs(self):
        bad_op_list1 = [2]

        with self.assertRaisesRegex(
            ValueError, "Expected op in `op_list` to be an OpOverload"
        ):
            create_selective_checkpoint_contexts(bad_op_list1)

        bad_op_list2 = [torch.ops.aten.sin]

        with self.assertRaisesRegex(
            ValueError, "update the OpOverloadPacket to a specific OpOverload"
        ):
            create_selective_checkpoint_contexts(bad_op_list2)

        with self.assertRaisesRegex(TypeError, "either a function or a list of ops."):
            create_selective_checkpoint_contexts(2)

    # Dynamo fails for various reasons:
    # - some tests use a custom op that does not implement Fake
    # - dynamo tries to trace into the saved variable hooks' unpack hook for some reason
    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_policy_with_state(self):
        # If the policy is a stateful callable, its state is shared between the
        # original forward and the recompute.
        counters = []

        class Policy:
            def __init__(self) -> None:
                self.counter = [0]
                self.recompute_counter = [0]

            def __call__(self, ctx, func, *args, **kwargs):
                counter = self.recompute_counter if ctx.is_recompute else self.counter
                counter[0] += 1
                counters.append(counter[0])
                if counter[0] == 1 and func is torch.ops.aten.mm.default:
                    return CheckpointPolicy.MUST_SAVE
                return CheckpointPolicy.PREFER_RECOMPUTE

        def fn(x):
            return x.sin().sin().sin()

        x = torch.randn(3, requires_grad=True)
        context_fn = functools.partial(
            create_selective_checkpoint_contexts,
            Policy(),
            allow_cache_entry_mutation=True,
        )
        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
        out.sum().backward()
        # 1. the recompute uses its own counter, which starts again from 0
        # 2. due to early-stop we do not recompute the final op
        self.assertEqual(counters, [1, 2, 3, 1, 2])

    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_storage_lifetime(self):
        from torch.utils._python_dispatch import _get_current_dispatch_mode
        from torch.utils.checkpoint import (
            _CachedTorchDispatchMode,
            _CachingTorchDispatchMode,
        )

        def policy_fn(ctx, op, *args, **kwargs):
            return CheckpointPolicy.MUST_SAVE

        ref = None

        def fn(x):
            nonlocal ref

            self.assertIsInstance(
                _get_current_dispatch_mode(),
                (_CachingTorchDispatchMode, _CachedTorchDispatchMode),
            )

            out = x.cos().exp()

            if isinstance(_get_current_dispatch_mode(), _CachingTorchDispatchMode):
                raw_val = (
                    _get_current_dispatch_mode()
                    .storage[torch.ops.aten.exp.default][0]
                    .val
                )
                # raw_val should have been detached to avoid the reference cycle:
                # graph -> saved variable hooks -> recompute_context -> storage -> graph
                self.assertFalse(raw_val.requires_grad)
                ref = weakref.ref(raw_val)

            # Careful for early-stop
            return out.sin()

        with disable_gc():
            # Case 1: If the graph goes away without backward, make sure there's no
            # reference cycle keeping the storage alive.
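            # With the cyclic garbage collector disabled, the storage must be freed by
            # reference counting alone; if a cycle like the one described above kept it
            # alive, the weakref below would not clear when `out` is deleted.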
            x = torch.randn(3, requires_grad=True)
            context_fn = functools.partial(
                create_selective_checkpoint_contexts, policy_fn
            )
            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
            self.assertIsNotNone(ref())
            del out
            self.assertIsNone(ref())

            # Case 2: After backward, even if retain_graph=True, the storage should go away
            x = torch.randn(3, requires_grad=True)
            context_fn = functools.partial(
                create_selective_checkpoint_contexts, policy_fn
            )
            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
            self.assertIsNotNone(ref())
            out.sum().backward(retain_graph=True)
            # The dispatch mode's storage should still be alive, but the entries should've
            # been cleared.
            self.assertIsNone(ref())

    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_version_counter(self):
        def policy_fn(ctx, op, *args, **kwargs):
            if op == torch.ops.aten.sin.default:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        def fn(x):
            return x.sin().mul_(2).cos().exp()

        x = torch.randn(3, requires_grad=True)
        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)

        # 1) Error because the output of sin is saved and mutated by mul_
        with self.assertRaisesRegex(RuntimeError, "has been mutated"):
            out.sum().backward()

        x = torch.randn(3, requires_grad=True)
        context_fn = functools.partial(
            create_selective_checkpoint_contexts,
            policy_fn,
            allow_cache_entry_mutation=True,
        )
        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)

        # 2) This should no longer be an error because of allow_cache_entry_mutation
        out.sum().backward()

    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_function_with_more_than_one_output(self):
        # maybe there is a more systematic way:
        counter = [0]

        def policy_fn(ctx, op, *args, **kwargs):
            if op == torch.ops.aten.var_mean.correction:
                counter[0] += 1
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        # var_mean has two outputs
        def fn(x):
            a, b = torch.var_mean(x)
            return a * b

        x = torch.randn(3, requires_grad=True)
        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
        x_grad = torch.autograd.grad(out.sum(), (x,))
        x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
        self.assertEqual(x_grad, x_grad_ref)
        self.assertEqual(counter[0], 2)

    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_function_with_non_tensor_output(self):
        # When SAC is enabled, the op is not computed a second time
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            counter = [0]

            @torch.library.custom_op("mylib::sin_with_extra", mutates_args=())
            def sin_with_extra(x: torch.Tensor) -> Tuple[torch.Tensor, int]:
                counter[0] += 1
                return x.sin(), 2

            def setup_context(ctx, inputs, output) -> None:
                (x,) = inputs
                ctx.save_for_backward(x)

            def backward(ctx, grad, _unused):
                (x,) = ctx.saved_tensors
                return grad * x.cos()

            torch.library.register_autograd(
                "mylib::sin_with_extra", backward, setup_context=setup_context
            )

            x = torch.randn(3, requires_grad=True)

            def fn(x):
                return (torch.ops.mylib.sin_with_extra(x)[0] * x.sin().exp()).sin()

            ops_list = [torch.ops.mylib.sin_with_extra.default]

            x = torch.randn(3, requires_grad=True)
            context_fn = functools.partial(
                create_selective_checkpoint_contexts, ops_list
            )
            out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
            x_grad = torch.autograd.grad(out.sum(), (x,))
            self.assertEqual(counter[0], 1)
            x_grad_ref = torch.autograd.grad(fn(x).sum(), (x,))
            self.assertEqual(x_grad, x_grad_ref)

    @skipIfTorchDynamo("compile tested in test/dynamo/test_activation_checkpointing.py")
    def test_can_only_trigger_recompute_once(self):
        # We don't support this to avoid adding extra complexity for now.
        # If there's a need, we could probably do some kind of use_count tracking.
        # TODO: have a nice error message here.
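        # Rough intuition: the first backward consumes the recomputed activations and
        # drops the recompute state, so even with retain_graph=True a second backward
        # cannot trigger the recompute again and is expected to fail below.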
        def policy_fn(ctx, op, *args, **kwargs):
            if op == torch.ops.aten.sin.default:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        def fn(x):
            return x.sin().cos().exp()

        x = torch.randn(3, requires_grad=True)
        context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
        out = checkpoint(fn, x, use_reentrant=False, context_fn=context_fn)
        out.sum().backward(retain_graph=True)

        with self.assertRaisesRegex(RuntimeError, "Trying to backward an extra time"):
            out.sum().backward(retain_graph=True)


class TestAutogradMultipleDispatch(TestCase):
    def test_autograd_multiple_dispatch_registrations(self, device):
        t = torch.randn(3, 3, device=device, requires_grad=True)
        # using _test_autograd_multiple_dispatch.fullcoverage which has
        # registrations in derivatives.yaml for Default, AutogradCUDA and NestedTensorAutograd
        out = torch._test_autograd_multiple_dispatch(t)
        grad = torch.randn(3, 3, device=device)
        out.backward(grad)

        if "cuda" not in device:
            # bogus default gradient registered for Autograd is grad + 1
            self.assertEqual(t.grad, grad + 1)
        else:
            # bogus gradient registered for AutogradCUDA is grad * 2
            self.assertEqual(t.grad, grad * 2)

        # test registered AutogradNestedTensor formula
        a = (
            torch.arange(6, dtype=torch.float, device=device)
            .reshape(2, 3)
            .requires_grad_(True)
        )
        b = (
            torch.arange(8, dtype=torch.float, device=device)
            .reshape(2, 4)
            .requires_grad_(True)
        )
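        # Packing a and b into a nested tensor routes the op through the
        # NestedTensorAutograd registration mentioned above, which has its own
        # (bogus) backward formula checked below.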
        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)

        nt_out = torch._test_autograd_multiple_dispatch(nt)
        c = torch.randn(2, 3, device=device)
        d = torch.randn(2, 4, device=device)
        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
        nt_out.backward(nt_grad)

        # bogus gradient for AutogradNestedTensor is grad * grad
        self.assertEqual(a.grad, c * c)
        self.assertEqual(b.grad, d * d)

    def test_autograd_composite_implicit_and_dispatch_registration(self, device):
        t = torch.randn(3, 3, device=device, requires_grad=True)
        # using _test_autograd_multiple_dispatch.ntonly
        # which has registrations in derivatives.yaml for NestedTensorAutograd and otherwise is CompositeImplicit
        out = torch._test_autograd_multiple_dispatch(t, True)
        grad = torch.randn(3, 3, device=device)
        out.backward(grad)

        # t.grad is just out.grad by composite op since _test_autograd_multiple_dispatch is just a clone
        self.assertEqual(t.grad, grad)

        # test registered AutogradNestedTensor formula
        a = (
            torch.arange(6, dtype=torch.float, device=device)
            .reshape(2, 3)
            .requires_grad_(True)
        )
        b = (
            torch.arange(8, dtype=torch.float, device=device)
            .reshape(2, 4)
            .requires_grad_(True)
        )
        nt = torch.nested.as_nested_tensor([a, b], dtype=torch.float, device=device)

        nt_out = torch._test_autograd_multiple_dispatch(nt, True)
        c = torch.randn(2, 3, device=device)
        d = torch.randn(2, 4, device=device)
        nt_grad = torch.nested.nested_tensor([c, d], dtype=torch.float, device=device)
        nt_out.backward(nt_grad)

        # bogus gradient for AutogradNestedTensor is grad * grad + grad
        self.assertEqual(a.grad, c * c + c)
        self.assertEqual(b.grad, d * d + d)

    def test_forward_mode_AD(self, device):
        # check that forward mode AD is only registered for the Default
        # dispatch for _test_autograd_multiple_dispatch.fullcoverage and not AutogradCUDA

        primal = torch.randn(3, device=device)
        tangent = torch.randn(3, device=device)

        with fwAD.dual_level():
            dual_input = fwAD.make_dual(primal, tangent)

            err_msg = r"Trying to use forward AD with .* that does not support it"
            hint_msg = "Running forward AD for an OP that does not implement it should raise a NotImplementedError"

            if "cuda" in device:
                with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
                    torch._test_autograd_multiple_dispatch(dual_input)
            else:
                torch._test_autograd_multiple_dispatch(dual_input)

    def test_view_copy(self, device):
        # tests that view_copy derivative formulas are also generated per dispatch key
        # from their respective view ops in derivatives.yaml
        t = torch.randn(2, 2, device=device, requires_grad=True)
        t_ref = t.clone().detach().requires_grad_()
        # _test_autograd_multiple_dispatch_view does a .view(-1) on the input
        t_view = torch._test_autograd_multiple_dispatch_view(t_ref)
        t_view_copy = torch._test_autograd_multiple_dispatch_view_copy(t)

        grad = torch.randn(4, device=device)
        t_view_copy.backward(grad)
        t_view.backward(grad.clone())

        # forward and backward give the same shape + result
        self.assertEqual(t_view_copy, t_view)
        self.assertEqual(t.grad, t_ref.grad)
        # backward results are per-dispatch-key in derivatives.yaml
        if "cuda" in device:
            # gradient registered to AutogradCUDA is grad.reshape_as(self) + 1
            self.assertEqual(t.grad, grad.reshape_as(t) + 1)
        else:
            # Default gradient registered is grad.reshape_as(self)
            self.assertEqual(t.grad, grad.reshape_as(t))

    @onlyCPU
    def test_per_dispatch_key_input_saving(self, device):
        # Tests that sum.dim_IntList's input is not saved for regular tensors but is saved for nested tensors
        def foo(x):
            # Don't modify the input inplace
            x = x.clone()
            res = x.sum(-1, keepdim=True)
            x.add_(x)
            return res

        inp = torch.rand(2, device=device, requires_grad=True)
        # sum's input is not saved for regular Tensors
        foo(inp).backward()

        # sum's input is saved for Nested Tensors
        nt = torch.nested.nested_tensor(
            [torch.rand(2), torch.rand(2)], device=device, requires_grad=True
        )
        with self.assertRaisesRegex(RuntimeError, "modified by an inplace operation"):
            foo(nt).backward(
                torch.nested.nested_tensor(
                    [torch.rand(1), torch.rand(1)], device=device
                )
            )

    @onlyCUDA
    def test_backward_single_threaded(self):
        threads_eq = None

        class TestFn(Function):
            @staticmethod
            def forward(ctx, x, self):
                ctx.self = self
                ctx.tid = threading.get_ident()
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                nonlocal threads_eq
                threads_eq = ctx.tid == threading.get_ident()
                return gO, None

        inp = torch.rand(10, device="cuda", requires_grad=True)

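        # With multithreading disabled, backward should run on the calling thread, so
        # the thread id recorded in forward should match the one observed in backward;
        # with the default multithreaded engine and a CUDA input, it should not.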
        with torch.autograd.set_multithreading_enabled(False):
            TestFn.apply(inp, None).sum().backward()
        self.assertTrue(threads_eq)

        TestFn.apply(inp, None).sum().backward()
        self.assertFalse(threads_eq)

    @onlyCUDA
    def test_backward_tls_stash(self):
        local = threading.local()
        local.my_obj = {}
        local.my_obj[10] = 10
        test_self = self
        torch._C._stash_obj_in_tls("my_obj", local.my_obj)

        class TestFn(Function):
            @staticmethod
            def forward(ctx, x, self):
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                test_self.assertTrue(torch._C._is_key_in_tls("my_obj"))
                test_self.assertTrue(torch._C._get_obj_in_tls("my_obj")[10] == 10)
                torch._C._get_obj_in_tls("my_obj")[10] = 5
                return gO, None

        inp = torch.rand(10, device="cuda", requires_grad=True)

        TestFn.apply(inp, None).sum().backward()
        self.assertEqual(local.my_obj[10], 5)

    def test_is_retain_graph(self):
        retain_graph_set = False

        class TestFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x.clone()

            @staticmethod
            def backward(ctx, gO):
                nonlocal retain_graph_set
                retain_graph_set = (
                    torch._C._autograd._get_current_graph_task_keep_graph()
                )
                return gO, None

        inp = torch.rand(10, requires_grad=True)

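        # retain_graph_set is only updated from inside backward via
        # _get_current_graph_task_keep_graph, so the first assert below checks the
        # initial value, and the later asserts reflect the retain_graph argument
        # passed to each backward call.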
        out = TestFn.apply(inp)
        self.assertFalse(retain_graph_set)
        out.sum().backward(retain_graph=True)
        self.assertTrue(retain_graph_set)
        out.sum().backward(retain_graph=False)
        self.assertFalse(retain_graph_set)

    def test_set_sequence_nr(self):
        x = torch.randn((10,), dtype=torch.float32, requires_grad=True)
        y = torch.randn((10,), dtype=torch.float32, requires_grad=True)
        z = torch.randn((10,), dtype=torch.float32, requires_grad=True)

        a = x + y
        b = y + z
        c = a + b

        self.assertIsNotNone(a.grad_fn)
        self.assertIsNotNone(b.grad_fn)
        self.assertIsNotNone(c.grad_fn)

        a.grad_fn._set_sequence_nr(100)
        b.grad_fn._set_sequence_nr(99)
        c.grad_fn._set_sequence_nr(98)

        self.assertEqual(a.grad_fn._sequence_nr(), 100)
        self.assertEqual(b.grad_fn._sequence_nr(), 99)
        self.assertEqual(c.grad_fn._sequence_nr(), 98)

        def log_grad_order(grad: torch.Tensor, name: str, order):
            order.append(name)
            return grad

        order = []
        a.register_hook(partial(log_grad_order, name="a", order=order))
        b.register_hook(partial(log_grad_order, name="b", order=order))
        c.register_hook(partial(log_grad_order, name="c", order=order))

        c.sum().backward()

        # Expect to see that even though c has the smallest sequence number, it is still the first node to get run in autograd.
        # Also check that although a comes first during the forward, after giving it priority with sequence_nr,
        # its autograd node is run before that of b.
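        # Among nodes that are ready at the same time, the engine appears to prefer the
        # one with the larger sequence_nr, which is why a (100) runs before b (99) even
        # though both only become ready once c has run.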
        self.assertEqual(order, ["c", "a", "b"])

        self.assertEqual(x.grad, torch.ones_like(x))
        self.assertEqual(y.grad, 2 * torch.ones_like(x))
        self.assertEqual(z.grad, torch.ones_like(x))


# Import test cases from below autograd/ here. These are found
# implicitly by the loader, so Flake8 thinks they are unused, hence
# the suppressions.

from autograd.test_complex import TestAutogradComplex  # noqa: F401
from autograd.test_functional import TestAutogradFunctional  # noqa: F401
from autograd.test_logging import TestAutogradLogging  # noqa: F401


# e.g., TestAutogradDeviceTypeCPU and TestAutogradDeviceTypeCUDA
instantiate_device_type_tests(TestAutogradDeviceType, globals(), except_for=None)

instantiate_device_type_tests(
    TestAutogradMultipleDispatch, globals(), only_for=("cpu", "cuda")
)

instantiate_parametrized_tests(TestAutograd)
instantiate_parametrized_tests(TestNestedCheckpoint)

if __name__ == "__main__":
    run_tests()