1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport contextlib 4*da0073e9SAndroid Build Coastguard Workerimport functools 5*da0073e9SAndroid Build Coastguard Workerimport unittest 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Workerimport torch 8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 11*da0073e9SAndroid Build Coastguard Workerfrom functorch.compile import nop 12*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import compiled_autograd 13*da0073e9SAndroid Build Coastguard Workerfrom torch._functorch.aot_autograd import aot_module_simplified 14*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.hooks import RemovableHandle 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker 17*da0073e9SAndroid Build Coastguard Workerdef compiler_fn(gm): 18*da0073e9SAndroid Build Coastguard Worker return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm) 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerdef global_hook_0(grad): 22*da0073e9SAndroid Build Coastguard Worker return grad * 4 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Worker 25*da0073e9SAndroid Build Coastguard Workerdef global_hook_1(grad): 26*da0073e9SAndroid Build Coastguard Worker return grad / 2 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Workerdef global_hook_2(grad): 30*da0073e9SAndroid Build Coastguard Worker return grad * 3 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker 33*da0073e9SAndroid Build 
Coastguard Workerh0 = None 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Workerclass ClassWithVal: 37*da0073e9SAndroid Build Coastguard Worker def __init__(self, val): 38*da0073e9SAndroid Build Coastguard Worker self.val = val 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker 41*da0073e9SAndroid Build Coastguard Workerclass HooksTests(torch._dynamo.test_case.TestCase): 42*da0073e9SAndroid Build Coastguard Worker def test_tensor_only_register_hook_in_graph_lambda(self): 43*da0073e9SAndroid Build Coastguard Worker def fn(x): 44*da0073e9SAndroid Build Coastguard Worker x.register_hook(lambda grad: grad * 2) 45*da0073e9SAndroid Build Coastguard Worker return x 46*da0073e9SAndroid Build Coastguard Worker 47*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 48*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 49*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 50*da0073e9SAndroid Build Coastguard Worker v = fn(v) 51*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 52*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 53*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 0) 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_hook_in_graph_lambda(self): 56*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 57*da0073e9SAndroid Build Coastguard Worker x.register_hook(lambda grad: grad * 2) 58*da0073e9SAndroid Build Coastguard Worker return x, y * y, z * z 59*da0073e9SAndroid Build Coastguard Worker 60*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 61*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 
62*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 63*da0073e9SAndroid Build Coastguard Worker v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0] 64*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 65*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 66*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_hook_in_graph_break_handle_lambda(self): 69*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 70*da0073e9SAndroid Build Coastguard Worker handle = x.register_hook(lambda grad: grad * 2) 71*da0073e9SAndroid Build Coastguard Worker z = z * z 72*da0073e9SAndroid Build Coastguard Worker handle.remove() 73*da0073e9SAndroid Build Coastguard Worker x.register_hook(lambda grad: grad * 3) 74*da0073e9SAndroid Build Coastguard Worker return x, y * y, z 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 77*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 78*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 79*da0073e9SAndroid Build Coastguard Worker v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0] 80*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 81*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0])) 82*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_hook_multi_handle_return(self): 85*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 86*da0073e9SAndroid Build Coastguard Worker handle = x.register_hook(lambda 
grad: grad * 2) 87*da0073e9SAndroid Build Coastguard Worker h2 = handle 88*da0073e9SAndroid Build Coastguard Worker z = z * z 89*da0073e9SAndroid Build Coastguard Worker return x, y * y, z, handle, h2 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 92*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 93*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 94*da0073e9SAndroid Build Coastguard Worker v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2])) 95*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 96*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 97*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 98*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(h, None) 99*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(h2, None) 100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(h2, h) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_hook_repeated_handle_return(self): 103*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 104*da0073e9SAndroid Build Coastguard Worker handle = x.register_hook(lambda grad: grad * 2) 105*da0073e9SAndroid Build Coastguard Worker h2 = handle 106*da0073e9SAndroid Build Coastguard Worker z = z * z 107*da0073e9SAndroid Build Coastguard Worker return x, y * y, z, handle, handle 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 110*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 111*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 112*da0073e9SAndroid Build Coastguard Worker v, y, z, h, h2 = fn(v, 
torch.randn([2, 2]), torch.randn([2, 2])) 113*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 115*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 116*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(h, RemovableHandle) 117*da0073e9SAndroid Build Coastguard Worker self.assertIs(h2, h) 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker def test_removed_handle_return(self): 120*da0073e9SAndroid Build Coastguard Worker cnt = torch._dynamo.testing.CompileCounter() 121*da0073e9SAndroid Build Coastguard Worker 122*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=cnt, fullgraph=True) 123*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 124*da0073e9SAndroid Build Coastguard Worker handle = x.register_hook(lambda grad: grad * 2) 125*da0073e9SAndroid Build Coastguard Worker z = z * z 126*da0073e9SAndroid Build Coastguard Worker handle.remove() 127*da0073e9SAndroid Build Coastguard Worker handle.remove() 128*da0073e9SAndroid Build Coastguard Worker return x, y * y, z, handle, handle 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 131*da0073e9SAndroid Build Coastguard Worker v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2])) 132*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 133*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0])) 134*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnt.frame_count, 1) 135*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(h, RemovableHandle) 136*da0073e9SAndroid Build Coastguard Worker self.assertIs(h2, h) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker 
def test_tensor_register_hook_repeated_handle_not_local(self): 139*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z, mod): 140*da0073e9SAndroid Build Coastguard Worker mod.handle = x.register_hook(lambda grad: grad * 2) 141*da0073e9SAndroid Build Coastguard Worker z = z * z 142*da0073e9SAndroid Build Coastguard Worker return x, y * y, z 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 145*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts, nopython=True)(fn) 146*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 147*da0073e9SAndroid Build Coastguard Worker 148*da0073e9SAndroid Build Coastguard Worker mod = torch.nn.Module() 149*da0073e9SAndroid Build Coastguard Worker mod.handle = None 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod) 152*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 153*da0073e9SAndroid Build Coastguard Worker 154*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 155*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 156*da0073e9SAndroid Build Coastguard Worker 157*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(mod.handle, None) 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Worker def test_tensor_only_register_hook_in_graph_local(self): 160*da0073e9SAndroid Build Coastguard Worker def local_hook(grad): 161*da0073e9SAndroid Build Coastguard Worker return grad * 2 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker def fn(x): 164*da0073e9SAndroid Build Coastguard Worker x.register_hook(local_hook) 165*da0073e9SAndroid Build Coastguard Worker return x 166*da0073e9SAndroid Build Coastguard 
Worker 167*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 168*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 169*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 170*da0073e9SAndroid Build Coastguard Worker v = fn(v) 171*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 172*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 173*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 0) 174*da0073e9SAndroid Build Coastguard Worker 175*da0073e9SAndroid Build Coastguard Worker def test_tensor_only_register_hook_in_graph_local_inner(self): 176*da0073e9SAndroid Build Coastguard Worker def fn(x): 177*da0073e9SAndroid Build Coastguard Worker def local_hook(grad): 178*da0073e9SAndroid Build Coastguard Worker return grad * 2 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker z = x * x 181*da0073e9SAndroid Build Coastguard Worker x.register_hook(local_hook) 182*da0073e9SAndroid Build Coastguard Worker z.register_hook(local_hook) 183*da0073e9SAndroid Build Coastguard Worker return x, z 184*da0073e9SAndroid Build Coastguard Worker 185*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 186*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 187*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 188*da0073e9SAndroid Build Coastguard Worker v = fn(v) 189*da0073e9SAndroid Build Coastguard Worker v[0].backward(torch.tensor([1.0, 2.0, 3.0])) 190*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0])) 191*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 192*da0073e9SAndroid Build Coastguard Worker 193*da0073e9SAndroid Build Coastguard Worker def 
test_tensor_register_hook_in_graph_local(self): 194*da0073e9SAndroid Build Coastguard Worker def local_hook(grad): 195*da0073e9SAndroid Build Coastguard Worker return grad * 2 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 198*da0073e9SAndroid Build Coastguard Worker x.register_hook(local_hook) 199*da0073e9SAndroid Build Coastguard Worker return x, y * y, z * z 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 202*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 203*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 204*da0073e9SAndroid Build Coastguard Worker v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0] 205*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 206*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0])) 207*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_hook_in_graph_break_handle_local(self): 210*da0073e9SAndroid Build Coastguard Worker def local_hook(grad): 211*da0073e9SAndroid Build Coastguard Worker return grad * 2 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Worker def local_hook2(grad): 214*da0073e9SAndroid Build Coastguard Worker return grad * 3 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker def fn(x, y, z): 217*da0073e9SAndroid Build Coastguard Worker handle = x.register_hook(local_hook) 218*da0073e9SAndroid Build Coastguard Worker z = z * z 219*da0073e9SAndroid Build Coastguard Worker handle.remove() 220*da0073e9SAndroid Build Coastguard Worker x.register_hook(local_hook2) 221*da0073e9SAndroid Build Coastguard Worker 
return x, y * y, z 222*da0073e9SAndroid Build Coastguard Worker 223*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 224*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 225*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 226*da0073e9SAndroid Build Coastguard Worker v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0] 227*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0])) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_global_hook(self): 232*da0073e9SAndroid Build Coastguard Worker def fn(x): 233*da0073e9SAndroid Build Coastguard Worker x.register_hook(global_hook_0) 234*da0073e9SAndroid Build Coastguard Worker return x, x * x 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 237*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 238*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 239*da0073e9SAndroid Build Coastguard Worker v = fn(v)[0] 240*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 241*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0])) 242*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_multiple_hooks(self): 245*da0073e9SAndroid Build Coastguard Worker def fn(x): 246*da0073e9SAndroid Build Coastguard Worker x.register_hook(global_hook_0) # * 4 247*da0073e9SAndroid Build Coastguard Worker x.register_hook(global_hook_1) # / 
2 248*da0073e9SAndroid Build Coastguard Worker x.register_hook(global_hook_2) # * 3 249*da0073e9SAndroid Build Coastguard Worker return x, x * x 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 252*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 253*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 254*da0073e9SAndroid Build Coastguard Worker v = fn(v)[0] 255*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 256*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0])) 257*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_multiple_hooks_handles_in_list(self): 260*da0073e9SAndroid Build Coastguard Worker def fn(x): 261*da0073e9SAndroid Build Coastguard Worker h0 = x.register_hook(global_hook_0) # * 4 262*da0073e9SAndroid Build Coastguard Worker h1 = x.register_hook(global_hook_1) # / 2 263*da0073e9SAndroid Build Coastguard Worker h2 = x.register_hook(global_hook_2) # * 3 264*da0073e9SAndroid Build Coastguard Worker return x, x * x, h0, h1, h2 265*da0073e9SAndroid Build Coastguard Worker 266*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 267*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 268*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 269*da0073e9SAndroid Build Coastguard Worker v, r, handle_0, handle_1, handle_2 = fn(v) 270*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 271*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0])) 272*da0073e9SAndroid Build Coastguard Worker handle_0.remove() 
273*da0073e9SAndroid Build Coastguard Worker handle_1.remove() 274*da0073e9SAndroid Build Coastguard Worker handle_2.remove() 275*da0073e9SAndroid Build Coastguard Worker 276*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 277*da0073e9SAndroid Build Coastguard Worker # Handles gone, grad is just applied as is 278*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0])) 279*da0073e9SAndroid Build Coastguard Worker 280*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 1) 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker def test_tensor_register_global_hooks_handles_in_list(self): 283*da0073e9SAndroid Build Coastguard Worker def fn(x): 284*da0073e9SAndroid Build Coastguard Worker global h0 285*da0073e9SAndroid Build Coastguard Worker h0 = x.register_hook(global_hook_0) # * 4 286*da0073e9SAndroid Build Coastguard Worker return x, x * x 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 289*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts)(fn) 290*da0073e9SAndroid Build Coastguard Worker v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True) 291*da0073e9SAndroid Build Coastguard Worker v, r = fn(v) 292*da0073e9SAndroid Build Coastguard Worker 293*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(h0) 294*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0])) 296*da0073e9SAndroid Build Coastguard Worker h0.remove() 297*da0073e9SAndroid Build Coastguard Worker 298*da0073e9SAndroid Build Coastguard Worker v.backward(torch.tensor([1.0, 2.0, 3.0])) 299*da0073e9SAndroid Build Coastguard Worker # Handles gone, grad is just applied as is 300*da0073e9SAndroid Build Coastguard Worker 
self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0])) 301*da0073e9SAndroid Build Coastguard Worker 302*da0073e9SAndroid Build Coastguard Worker # NYI! 303*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 0) 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker def test_intermediary_hooks(self): 306*da0073e9SAndroid Build Coastguard Worker # Graph breaks because compiled_autograd is not set 307*da0073e9SAndroid Build Coastguard Worker def simple_hook(g): 308*da0073e9SAndroid Build Coastguard Worker return g * 2 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Worker def f(x): 311*da0073e9SAndroid Build Coastguard Worker y = x + 1 312*da0073e9SAndroid Build Coastguard Worker y.register_hook(simple_hook) 313*da0073e9SAndroid Build Coastguard Worker z = y + 1 314*da0073e9SAndroid Build Coastguard Worker return z 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker out = torch.randn(1, requires_grad=True) 317*da0073e9SAndroid Build Coastguard Worker cnts = torch._dynamo.testing.CompileCounter() 318*da0073e9SAndroid Build Coastguard Worker fn = torch._dynamo.optimize(cnts, nopython=False)(f) 319*da0073e9SAndroid Build Coastguard Worker res = fn(out) 320*da0073e9SAndroid Build Coastguard Worker res.backward() 321*da0073e9SAndroid Build Coastguard Worker self.assertEqual(res, f(out)) 322*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cnts.frame_count, 2) 323*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.grad, torch.Tensor([2.0])) 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker def test_intermediary_hooks_same_on_aot_eager(self): 326*da0073e9SAndroid Build Coastguard Worker def my_hook(grad, *, k=0): 327*da0073e9SAndroid Build Coastguard Worker return grad + k 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker class 
MyMod(torch.nn.Module): 330*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 331*da0073e9SAndroid Build Coastguard Worker y = x.mul(2) 332*da0073e9SAndroid Build Coastguard Worker hook1 = functools.partial(my_hook, k=3) 333*da0073e9SAndroid Build Coastguard Worker hook2 = functools.partial(my_hook, k=4) 334*da0073e9SAndroid Build Coastguard Worker y.register_hook(hook1) 335*da0073e9SAndroid Build Coastguard Worker y.register_hook(hook2) 336*da0073e9SAndroid Build Coastguard Worker z = y.mul(3) 337*da0073e9SAndroid Build Coastguard Worker return (z,) 338*da0073e9SAndroid Build Coastguard Worker 339*da0073e9SAndroid Build Coastguard Worker mod = MyMod() 340*da0073e9SAndroid Build Coastguard Worker x0 = torch.ones(4, requires_grad=True) 341*da0073e9SAndroid Build Coastguard Worker eager_out = mod(x0) 342*da0073e9SAndroid Build Coastguard Worker eager_out[0].backward(torch.ones(4)) 343*da0073e9SAndroid Build Coastguard Worker 344*da0073e9SAndroid Build Coastguard Worker x1 = torch.ones(4, requires_grad=True) 345*da0073e9SAndroid Build Coastguard Worker mod_compiled = aot_module_simplified(mod, (x1,), nop) 346*da0073e9SAndroid Build Coastguard Worker aot_out = mod_compiled(x1) 347*da0073e9SAndroid Build Coastguard Worker aot_out[0].backward(torch.ones(4)) 348*da0073e9SAndroid Build Coastguard Worker 349*da0073e9SAndroid Build Coastguard Worker x2 = torch.ones(4, requires_grad=True) 350*da0073e9SAndroid Build Coastguard Worker with compiled_autograd.enable(compiler_fn): 351*da0073e9SAndroid Build Coastguard Worker dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2) 352*da0073e9SAndroid Build Coastguard Worker dynamo_out[0].backward(torch.ones(4)) 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamo_out, aot_out) 355*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamo_out, eager_out) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard 
Worker self.assertEqual(x0.grad, x1.grad) 358*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x0.grad, x2.grad) 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker def test_input_hooks_same(self): 361*da0073e9SAndroid Build Coastguard Worker backends = ["eager", "aot_eager", "inductor"] 362*da0073e9SAndroid Build Coastguard Worker for backend in backends: 363*da0073e9SAndroid Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker def my_hook(grad, *, k=0): 365*da0073e9SAndroid Build Coastguard Worker return grad + k 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker hook = functools.partial(my_hook, k=3) 368*da0073e9SAndroid Build Coastguard Worker 369*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 370*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 371*da0073e9SAndroid Build Coastguard Worker x.register_hook(hook) 372*da0073e9SAndroid Build Coastguard Worker y = x.mul(2) 373*da0073e9SAndroid Build Coastguard Worker z = y.mul(3) 374*da0073e9SAndroid Build Coastguard Worker return (z,) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker mod = MyMod() 377*da0073e9SAndroid Build Coastguard Worker x0 = torch.ones(4, requires_grad=True) 378*da0073e9SAndroid Build Coastguard Worker eager_out = mod(x0) 379*da0073e9SAndroid Build Coastguard Worker eager_out[0].backward(torch.ones(4)) 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker x1 = torch.ones(4, requires_grad=True) 382*da0073e9SAndroid Build Coastguard Worker mod_compiled = aot_module_simplified(mod, (x1,), nop) 383*da0073e9SAndroid Build Coastguard Worker aot_out = mod_compiled(x1) 384*da0073e9SAndroid Build Coastguard Worker aot_out[0].backward(torch.ones(4)) 385*da0073e9SAndroid Build Coastguard Worker 386*da0073e9SAndroid Build Coastguard Worker x2 = torch.ones(4, requires_grad=True) 
387*da0073e9SAndroid Build Coastguard Worker dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2) 388*da0073e9SAndroid Build Coastguard Worker with compiled_autograd.enable(compiler_fn): 389*da0073e9SAndroid Build Coastguard Worker dynamo_out[0].backward(torch.ones(4)) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamo_out, aot_out) 392*da0073e9SAndroid Build Coastguard Worker self.assertEqual(dynamo_out, eager_out) 393*da0073e9SAndroid Build Coastguard Worker 394*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x0.grad, x1.grad) 395*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x0.grad, x2.grad) 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker def test_intermediary_hooks_same_on_inductor(self): 398*da0073e9SAndroid Build Coastguard Worker def my_hook(grad, *, k=0): 399*da0073e9SAndroid Build Coastguard Worker return grad + k 400*da0073e9SAndroid Build Coastguard Worker 401*da0073e9SAndroid Build Coastguard Worker class MyMod(torch.nn.Module): 402*da0073e9SAndroid Build Coastguard Worker def forward(self, x): 403*da0073e9SAndroid Build Coastguard Worker y = x.mul(2) 404*da0073e9SAndroid Build Coastguard Worker hook1 = functools.partial(my_hook, k=3) 405*da0073e9SAndroid Build Coastguard Worker hook2 = functools.partial(my_hook, k=4) 406*da0073e9SAndroid Build Coastguard Worker y.register_hook(hook1) 407*da0073e9SAndroid Build Coastguard Worker y.register_hook(hook2) 408*da0073e9SAndroid Build Coastguard Worker z = y.mul(3) 409*da0073e9SAndroid Build Coastguard Worker return (z,) 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker mod = MyMod() 412*da0073e9SAndroid Build Coastguard Worker x0 = torch.ones(4, requires_grad=True) 413*da0073e9SAndroid Build Coastguard Worker eager_out = mod(x0) 414*da0073e9SAndroid Build Coastguard Worker eager_out[0].backward(torch.ones(4)) 
# NOTE(review): everything up to the next `def` is the tail of a test method
# (presumably on HooksTests) whose beginning lies before this chunk -- `mod`,
# `x0` and `eager_out` are defined there. Statements are unchanged; only the
# indentation has been reconstructed from the collapsed source.

        # Reference run through AOT autograd with the no-op `nop` compiler.
        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        # Dynamo/inductor forward; backward runs while compiled autograd
        # (compiler_fn) is enabled.
        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
            dynamo_out[0].backward(torch.ones(4))

        # Eager, AOT and dynamo paths must agree on outputs and input grads.
        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)

        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
        """Stateful Python hooks behave identically under compiled autograd.

        Two partial-bound hooks on the same intermediate share one
        ``SomePyClass`` instance whose counter alternates the hook between
        squaring and doubling the gradient. The compiled run must match eager
        outputs and grads, and must advance the counter exactly as eager does
        (two calls per backward).
        """

        class SomePyClass:
            count = 0

            def do_stuff(self, grad):
                # Even count -> square the grad, odd count -> double it.
                if self.count % 2 == 0:
                    r = grad * grad
                else:
                    r = grad + grad
                self.count += 1
                return r

        def complex_state_touching_hook(grad, *, obj):
            return obj.do_stuff(grad)

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                # Two distinct partial objects wrapping the same hook and obj.
                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj = SomePyClass()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0, obj)
        eager_out[0].backward(torch.ones(4))

        # Eager 2
        self.assertEqual(obj.count, 2)
        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, eager_out)

        # Eager 2 + compiled 2
        self.assertEqual(obj.count, 4)
        self.assertEqual(x0.grad, x2.grad)

    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
        self,
    ):
        """A hook that calls ``str()`` on the grad cannot be compiled fullgraph.

        ``compiler_fn`` compiles the backward with ``nopython=True``, so the
        ``str(grad)`` call in the hook is expected to surface as
        ``torch._dynamo.exc.Unsupported`` ("builtin: str") during backward.
        Afterwards ``obj.count`` must still be 2, i.e. only the two eager hook
        calls were counted.
        """

        class SomePyClass:
            grad_as_str = "None"
            count = 0

            def write_grad_as_str_and_do_stuff(self, grad):
                # str() on a tensor is the untraceable call this test targets.
                self.grad_as_str = str(grad)
                if self.count % 2 == 0:
                    r = grad * grad
                else:
                    r = grad + grad
                print("Break!")
                self.count += 1
                return r

        def complex_state_touching_hook(grad, *, obj):
            return obj.write_grad_as_str_and_do_stuff(grad)

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj = SomePyClass()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0, obj)
        eager_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
                dynamo_out[0].backward(torch.ones(4))

        # Still only the two eager increments.
        self.assertEqual(obj.count, 2)

    def test_register_hook_partial_guarding(
        self,
    ):
        """Binding a partial hook to different objects does not recompile.

        The forward registers ``functools.partial(some_hook, obj=obj)`` on an
        intermediate; calling it with different inputs and with obj1/obj2
        (tensor vals) and obj3 (a plain int val) must compile only one frame.
        Only forward compilation is checked; no backward is run.
        """

        def some_hook(grad, *, obj):
            return grad + obj.val

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(some_hook, obj=obj)
                y.register_hook(hook1)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        obj3 = ClassWithVal(11)
        cnt = torch._dynamo.testing.CompileCounter()

        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)

        with compiled_autograd.enable(compiler_fn):
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
            torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
            self.assertEqual(cnt.frame_count, 1)

    def test_hook_with_closure(self):
        """A lambda hook on a leaf input that closes over ``obj``.

        Switching from obj1 to obj2 must not recompile either the forward
        (``cnt_fw``) or the compiled-autograd backward (``cnt_bw``), and the
        resulting grads must match the eager runs.
        """

        def fn(x, obj):
            y = x.sin()
            # Hook on the leaf input x (not on an intermediate).
            x.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_fw = torch._dynamo.testing.CompileCounter()
        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        # Eager reference grads.
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_fw.frame_count, 1)
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_closure_eager(self):
        """Closure-capturing lambda hook registered on an intermediate tensor.

        Forward is compiled with a ``CompileCounter`` backend. Neither the
        forward nor the compiled-autograd backward may recompile when ``obj``
        changes, and grads must match eager.
        """

        def fn(x, obj):
            y = x.sin()
            y.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_fw = torch._dynamo.testing.CompileCounter()
        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        # Eager reference grads.
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_fw.frame_count, 1)
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_closure_aot(self):
        """Same intermediate-hook scenario with the forward on ``aot_eager``.

        Only the compiled-autograd backward is counted here; it must compile
        once across obj1/obj2, and grads must match eager.
        """

        def fn(x, obj):
            y = x.sin()
            y.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend="aot_eager", fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        # Eager reference grads.
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_no_recompile_on_hook_identity_change(self):
        """Swapping which function a hook name refers to does not recompile.

        ``forward`` registers ``my_hook`` (read from the enclosing scope)
        twice. After rebinding ``my_hook`` to ``my_hook2``, the compiled module
        must keep a frame count of 1 while its grads still match eager.
        """

        def my_hook(grad, k=0):
            return grad + k

        def my_hook2(grad):
            return grad * 2

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                y.register_hook(my_hook)
                y.register_hook(my_hook)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
            comp_out = comp_mod(x1)
            comp_out[0].backward(torch.ones(4))

            self.assertEqual(cnts.frame_count, 1)
            # forward reads `my_hook` through the enclosing closure, so this
            # rebinding changes the hook registered on the next call (eager
            # `mod(x0)` included).
            my_hook = my_hook2  # noqa: F811
            self.assertEqual(x0.grad, x1.grad)

            # x0.grad / x1.grad are not reset, so grads accumulate here.
            eager_out = mod(x0)
            eager_out[0].backward(torch.ones(4))

            comp_out = comp_mod(x1)

            # The hook identity change must not have caused a recompile.
            self.assertEqual(cnts.frame_count, 1)
            comp_out[0].backward(torch.ones(4))
            self.assertEqual(x0.grad, x1.grad)

    def test_functools_arg_vary(self):
        """Changing a partial hook's bound kwarg is reflected after compile.

        ``h`` reads ``hook`` from the enclosing scope; rebinding it from a
        ``k=1`` partial to a ``k=2`` partial must double the resulting grad.
        """

        def pre_hook(grad, *, k):
            return grad * k

        hook = functools.partial(pre_hook, k=1)

        @torch.compile(backend="eager", fullgraph=True)
        def h(x):
            y = x.mul(2)
            y.register_hook(hook)
            return y.mul(3)

        with compiled_autograd.enable(torch.compile(backend="eager", fullgraph=True)):
            x = torch.randn(2, requires_grad=True)
            h(x).sum().backward()
            orig_grad = x.grad
            x.grad = None

            hook = functools.partial(pre_hook, k=2)
            h(x).sum().backward()
            self.assertEqual(orig_grad * 2, x.grad)

    def test_post_acc_grad_hook(self):
        """Post-accumulate-grad hooks that mutate the param and its ``.grad``.

        The same assertions run for eager and for every combination of
        forward backend ("eager", "aot_eager", "inductor") with and without
        compiled autograd.
        """

        def hook(input_t):
            # In-place updates on both the tensor and its accumulated grad.
            input_t.mul_(input_t.grad)
            input_t.grad.mul_(5)

        def reg_and_mul(x, y):
            x.register_post_accumulate_grad_hook(hook)
            return x * y

        # None for the eager run, so the frame-count check below is skipped.
        cnts = None

        def test_fn(fn):
            # Reads x, y and cnts from the enclosing scope; they are rebound
            # for each compiled configuration below. fn's output is unused --
            # calling it only registers the hook; backward runs on x directly.
            fn(x, y)
            b = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
            x.backward(b)
            if cnts:
                self.assertEqual(cnts.frame_count, 1)
            # These same exact assertions run on both eager and compiled
            # X goes to x*2 because of mul_
            self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
            # Proves grad aliasing works: the hook's in-place mul_ on
            # input_t.grad is visible through x.grad.
            self.assertEqual(x.grad, b * 5)

        # Eager values
        x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
        y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        test_fn(reg_and_mul)

        # Compiled
        for backend in ["eager", "aot_eager", "inductor"]:
            for compiled_bwd in [False, True]:
                torch._dynamo.reset()
                x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
                y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

                cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
                compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)

                compiled_bwd_ctx = (
                    compiled_autograd.enable(
                        torch.compile(backend=backend, fullgraph=True)
                    )
                    if compiled_bwd
                    else contextlib.nullcontext()
                )
                with compiled_bwd_ctx:
                    test_fn(compiled_fn)

    def test_recompile(self):
        """Repeated backward passes with a post-acc-grad hook.

        Runs five backward passes under compiled autograd with
        ``error_on_recompile`` patched to True, which is expected to turn any
        recompilation into an error; the test passes if none is raised.
        """

        def hook(param):
            param.grad *= 2

        x = torch.ones(10)
        x.requires_grad = True

        def run(input):
            return x * input

        x.register_post_accumulate_grad_hook(hook)
        with compiled_autograd.enable(compiler_fn):
            for i in range(5):
                with unittest.mock.patch(
                    "torch._dynamo.config.error_on_recompile", True
                ):
                    # Mimic optimizer.zero_grad() to clear the gradient
                    x.grad = None
                    run(i).sum().backward()

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_no_recompile_on_same_hook(self):
        """Ten separately compiled Linear layers share one compiled frame.

        Each layer gets an equivalent forward pre-hook (a fresh lambda that
        calls ``fw_hook``) and is wrapped with ``torch.compile``. With
        ``inline_inbuilt_nn_modules=True`` the whole forward must compile only
        once. No backward is run.
        """
        cnts = torch._dynamo.testing.CompileCounter()

        def fw_hook(inp):
            return (inp[0] + 1,)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.ModuleList()
                for i in range(10):
                    layer = torch.nn.Linear(16, 16)
                    # A new lambda object per layer; identity differs, body
                    # is the same.
                    layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp))
                    layer = torch.compile(layer, backend=cnts)
                    self.layers.append(layer)

            def forward(self, x):
                for l in self.layers:
                    x = l(x)
                return x

        mod = Mod()
        x = torch.ones(16, 16, requires_grad=True)
        mod(x)

        self.assertEqual(cnts.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()