# Owner(s): ["module: dynamo"]

import contextlib
import functools
import unittest

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._dynamo import compiled_autograd
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils.hooks import RemovableHandle


def compiler_fn(gm):
    return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)


def global_hook_0(grad):
    return grad * 4


def global_hook_1(grad):
    return grad / 2


def global_hook_2(grad):
    return grad * 3


h0 = None


class ClassWithVal:
    def __init__(self, val):
        self.val = val


class HooksTests(torch._dynamo.test_case.TestCase):
    def test_tensor_only_register_hook_in_graph_lambda(self):
        def fn(x):
            x.register_hook(lambda grad: grad * 2)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_register_hook_in_graph_lambda(self):
        def fn(x, y, z):
            x.register_hook(lambda grad: grad * 2)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_lambda(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            handle.remove()
            x.register_hook(lambda grad: grad * 3)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_multi_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(h, None)
        self.assertNotEqual(h2, None)
        self.assertEqual(h2, h)

    def test_tensor_register_hook_repeated_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, handle

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertIsInstance(h, RemovableHandle)
        self.assertIs(h2, h)

    def test_removed_handle_return(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            handle.remove()
            handle.remove()
            return x, y * y, z, handle, handle

        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertIsInstance(h, RemovableHandle)
        self.assertIs(h2, h)

    def test_tensor_register_hook_repeated_handle_not_local(self):
        def fn(x, y, z, mod):
            mod.handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)

        mod = torch.nn.Module()
        mod.handle = None

        v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))

        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

        self.assertNotEqual(mod.handle, None)

    def test_tensor_only_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x):
            x.register_hook(local_hook)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_only_register_hook_in_graph_local_inner(self):
        def fn(x):
            def local_hook(grad):
                return grad * 2

            z = x * x
            x.register_hook(local_hook)
            z.register_hook(local_hook)
            return x, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x, y, z):
            x.register_hook(local_hook)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_local(self):
        def local_hook(grad):
            return grad * 2

        def local_hook2(grad):
            return grad * 3

        def fn(x, y, z):
            handle = x.register_hook(local_hook)
            z = z * z
            handle.remove()
            x.register_hook(local_hook2)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))

        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))

    def test_tensor_register_global_hook(self):
        def fn(x):
            x.register_hook(global_hook_0)
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks(self):
        def fn(x):
            x.register_hook(global_hook_0)  # * 4
            x.register_hook(global_hook_1)  # / 2
            x.register_hook(global_hook_2)  # * 3
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks_handles_in_list(self):
        def fn(x):
            h0 = x.register_hook(global_hook_0)  # * 4
            h1 = x.register_hook(global_hook_1)  # / 2
            h2 = x.register_hook(global_hook_2)  # * 3
            return x, x * x, h0, h1, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r, handle_0, handle_1, handle_2 = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        handle_0.remove()
        handle_1.remove()
        handle_2.remove()

        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handles gone, grad is just applied as is
        self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))

        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_global_hooks_handles_in_list(self):
        def fn(x):
            global h0
            h0 = x.register_hook(global_hook_0)  # * 4
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r = fn(v)

        self.assertIsNotNone(h0)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        h0.remove()

        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Handles gone, grad is just applied as is
        self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))

        # NYI!
        self.assertEqual(cnts.frame_count, 0)

    def test_intermediary_hooks(self):
        # Graph breaks because compiled_autograd is not set
        def simple_hook(g):
            return g * 2

        def f(x):
            y = x + 1
            y.register_hook(simple_hook)
            z = y + 1
            return z

        out = torch.randn(1, requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts, nopython=False)(f)
        res = fn(out)
        res.backward()
        self.assertEqual(res, f(out))
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(out.grad, torch.Tensor([2.0]))

    def test_intermediary_hooks_same_on_aot_eager(self):
        def my_hook(grad, *, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)

        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_input_hooks_same(self):
        backends = ["eager", "aot_eager", "inductor"]
        for backend in backends:

            def my_hook(grad, *, k=0):
                return grad + k

            hook = functools.partial(my_hook, k=3)

            class MyMod(torch.nn.Module):
                def forward(self, x):
                    x.register_hook(hook)
                    y = x.mul(2)
                    z = y.mul(3)
                    return (z,)

            mod = MyMod()
            x0 = torch.ones(4, requires_grad=True)
            eager_out = mod(x0)
            eager_out[0].backward(torch.ones(4))

            x1 = torch.ones(4, requires_grad=True)
            mod_compiled = aot_module_simplified(mod, (x1,), nop)
            aot_out = mod_compiled(x1)
            aot_out[0].backward(torch.ones(4))

            x2 = torch.ones(4, requires_grad=True)
            dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
            with compiled_autograd.enable(compiler_fn):
                dynamo_out[0].backward(torch.ones(4))

            self.assertEqual(dynamo_out, aot_out)
            self.assertEqual(dynamo_out, eager_out)

            self.assertEqual(x0.grad, x1.grad)
            self.assertEqual(x0.grad, x2.grad)

    def test_intermediary_hooks_same_on_inductor(self):
        def my_hook(grad, *, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)

        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
        class SomePyClass:
            count = 0

            def do_stuff(self, grad):
                if self.count % 2 == 0:
                    r = grad * grad
                else:
                    r = grad + grad
                self.count += 1
                return r

        def complex_state_touching_hook(grad, *, obj):
            return obj.do_stuff(grad)

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj = SomePyClass()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0, obj)
        eager_out[0].backward(torch.ones(4))

        # Eager 2
        self.assertEqual(obj.count, 2)
        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, eager_out)

        # Eager 2 + compiled 2
        self.assertEqual(obj.count, 4)
        self.assertEqual(x0.grad, x2.grad)

    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
        self,
    ):
        class SomePyClass:
            grad_as_str = "None"
            count = 0

            def write_grad_as_str_and_do_stuff(self, grad):
                self.grad_as_str = str(grad)
                if self.count % 2 == 0:
                    r = grad * grad
                else:
                    r = grad + grad
                print("Break!")
                self.count += 1
                return r

        def complex_state_touching_hook(grad, *, obj):
            return obj.write_grad_as_str_and_do_stuff(grad)

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj = SomePyClass()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0, obj)
        eager_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
                dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(obj.count, 2)

    def test_register_hook_partial_guarding(
        self,
    ):
        def some_hook(grad, *, obj):
            return grad + obj.val

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(some_hook, obj=obj)
                y.register_hook(hook1)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        obj3 = ClassWithVal(11)
        cnt = torch._dynamo.testing.CompileCounter()

        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)

        with compiled_autograd.enable(compiler_fn):
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
            torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
            self.assertEqual(cnt.frame_count, 1)

    def test_hook_with_closure(self):
        def fn(x, obj):
            y = x.sin()
            x.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_fw = torch._dynamo.testing.CompileCounter()
        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_fw.frame_count, 1)
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_closure_eager(self):
        def fn(x, obj):
            y = x.sin()
            y.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_fw = torch._dynamo.testing.CompileCounter()
        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_fw.frame_count, 1)
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_closure_aot(self):
        def fn(x, obj):
            y = x.sin()
            y.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend="aot_eager", fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_no_recompile_on_hook_identity_change(self):
        def my_hook(grad, k=0):
            return grad + k

        def my_hook2(grad):
            return grad * 2

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                y.register_hook(my_hook)
                y.register_hook(my_hook)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
            comp_out = comp_mod(x1)
            comp_out[0].backward(torch.ones(4))

            self.assertEqual(cnts.frame_count, 1)
            my_hook = my_hook2  # noqa: F811
            self.assertEqual(x0.grad, x1.grad)

            eager_out = mod(x0)
            eager_out[0].backward(torch.ones(4))

            comp_out = comp_mod(x1)

            self.assertEqual(cnts.frame_count, 1)
            comp_out[0].backward(torch.ones(4))
            self.assertEqual(x0.grad, x1.grad)

    def test_functools_arg_vary(self):
        def pre_hook(grad, *, k):
            return grad * k

        hook = functools.partial(pre_hook, k=1)

        @torch.compile(backend="eager", fullgraph=True)
        def h(x):
            y = x.mul(2)
            y.register_hook(hook)
            return y.mul(3)

        with compiled_autograd.enable(torch.compile(backend="eager", fullgraph=True)):
            x = torch.randn(2, requires_grad=True)
            h(x).sum().backward()
            orig_grad = x.grad
            x.grad = None

            hook = functools.partial(pre_hook, k=2)
            h(x).sum().backward()
            self.assertEqual(orig_grad * 2, x.grad)

    def test_post_acc_grad_hook(self):
        def hook(input_t):
            input_t.mul_(input_t.grad)
            input_t.grad.mul_(5)

        def reg_and_mul(x, y):
            x.register_post_accumulate_grad_hook(hook)
            return x * y

        cnts = None

        def test_fn(fn):
            fn(x, y)
            b = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
            x.backward(b)
            if cnts:
                self.assertEqual(cnts.frame_count, 1)
            # These same exact assertions run on both eager and compiled
            # X goes to x*2 because of mul_
            self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
            # This test proves grad aliasing works -
            self.assertEqual(x.grad, b * 5)

        # Eager values
        x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
        y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        test_fn(reg_and_mul)

        # Compiled
        for backend in ["eager", "aot_eager", "inductor"]:
            for compiled_bwd in [False, True]:
                torch._dynamo.reset()
                x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
                y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

                cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
                compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)

                compiled_bwd_ctx = (
                    compiled_autograd.enable(
                        torch.compile(backend=backend, fullgraph=True)
                    )
                    if compiled_bwd
                    else contextlib.nullcontext()
                )
                with compiled_bwd_ctx:
                    test_fn(compiled_fn)

    def test_recompile(self):
        def hook(param):
            param.grad *= 2

        x = torch.ones(10)
        x.requires_grad = True

        def run(input):
            return x * input

        x.register_post_accumulate_grad_hook(hook)
        with compiled_autograd.enable(compiler_fn):
            for i in range(5):
                with unittest.mock.patch(
                    "torch._dynamo.config.error_on_recompile", True
                ):
                    # Mimic optimizer.zero_grad() to clear the gradient
                    x.grad = None
                    run(i).sum().backward()

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_no_recompile_on_same_hook(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def fw_hook(inp):
            return (inp[0] + 1,)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.ModuleList()
                for i in range(10):
                    layer = torch.nn.Linear(16, 16)
                    layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp))
                    layer = torch.compile(layer, backend=cnts)
                    self.layers.append(layer)

            def forward(self, x):
                for l in self.layers:
                    x = l(x)
                return x

        mod = Mod()
        x = torch.ones(16, 16, requires_grad=True)
        mod(x)

        self.assertEqual(cnts.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()