# Owner(s): ["module: dynamo"]
import unittest

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_utils import TEST_WITH_ROCM


class CustomizedCtxManager:
    def __init__(self, mode):
        self.prev = torch.is_grad_enabled()
        self.mode = mode

    def __enter__(self):
        torch._C._set_grad_enabled(self.mode)

    def __exit__(self, exc_type, exc_value, traceback):
        torch._C._set_grad_enabled(self.prev)


class CustomizedCtxManagerWithGraphBreak(CustomizedCtxManager):
    def __enter__(self):
        torch._dynamo.graph_break()
        super().__enter__()


class CtxManagerTests(torch._dynamo.test_case.TestCase):
    def test_no_grad(self):
        def fn1(a, b):
            x = a + 1
            # redundant no_grad should get ignored
            with torch.no_grad():
                x = x + b
            x = x + 2
            return x

        def fn2(a, b):
            x = a + 1
            with torch.set_grad_enabled(False):
                x = x + b
            x = x + 2
            return x

        def fn3(a, b):
            x = a + 1
            with torch.enable_grad():
                x = x + b
            x = x + 2
            return x

        def fn4(a, b):
            x = a + 1
            with torch.set_grad_enabled(True):
                if torch.is_grad_enabled():
                    x = x + b
            x = x + 2
            return x

        with torch.no_grad():
            torch._dynamo.testing.standard_test(
                self, fn=fn1, nargs=2, expected_ops=3
            )  # coalesced noop
            torch._dynamo.testing.standard_test(
                self, fn=fn2, nargs=2, expected_ops=3
            )  # coalesced noop
            torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
        with torch.enable_grad():
            torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
            torch._dynamo.testing.standard_test(
                self, fn=fn3, nargs=2, expected_ops=3
            )  # coalesced noop
            torch._dynamo.testing.standard_test(
                self, fn=fn4, nargs=2, expected_ops=3
            )  # coalesced noop

    def test_grad_mode_guard(self):
        def fn(a, b):
            prev_grad = torch.is_grad_enabled()
            torch.set_grad_enabled(False)
            a = a + 1
            a.tolist()  # graph break
            ret = a + b
            torch.set_grad_enabled(prev_grad)
            return ret

        a = torch.randn([3, 4])
        b = torch.randn([3, 4])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        for _ in range(10):
            opt_fn(a, b)
        self.assertEqual(cnts.frame_count, 2)

    def test_nested_grad_mode_graph_break(self):
        def fn(x):
            before = torch.is_grad_enabled()
            with torch.set_grad_enabled(False):
                torch._dynamo.graph_break()
                with torch.set_grad_enabled(True):
                    x = torch.mul(x, 5)
                    torch._dynamo.graph_break()
                    x = torch.sqrt(x)
                    assert torch.is_grad_enabled()
                assert not torch.is_grad_enabled()
            assert torch.is_grad_enabled() == before
            return x

        a = torch.randn([3, 4])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)

        for _ in range(10):
            opt_fn(a)
        self.assertEqual(cnts.frame_count, 2)

    def test_torch_profiler(self):
        # wrap torch.profiler.* as NullContextVariable and do nothing
        def fn(x):
            y = x**2
            with torch.profiler.profile():
                y = y + 2
                with torch.profiler.record_function("my_function"):
                    z = y**3
                    z.tolist()  # graph break
                    z = z + 1
            return z

        x = torch.randn((2, 2), requires_grad=True)
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 2)

    def test_autograd_profiler(self):
        # wrap torch.autograd.profiler.* as NullContextVariable and do nothing
        def fn(x):
            y = x**2
            with torch.autograd.profiler.profile():
                y = y + 2
                with torch.autograd.profiler.record_function("my_function"):
                    z = y**3
                    z.tolist()  # graph break
                    z = z + 1
            return z

        x = torch.randn((2, 2), requires_grad=True)
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))
        self.assertEqual(cnts.frame_count, 2)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_context_manager1(self):
        def fn(x):
            s = torch.cuda.Stream()
            x = torch.mul(x, 5)
            x = torch.add(x, 2)
            current_stream = torch.cuda.current_stream()
            s.wait_stream(current_stream)
            with torch.cuda.stream(s):
                x = torch.relu(x)
            current_stream.wait_stream(s)
            x = torch.add(x, 1)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 12)

    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/118204
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_across_graph_break(self):
        def fn(x):
            s = torch.cuda.Stream()
            x = torch.mul(x, 5)
            x = torch.add(x, 2)

            print("foo")

            tcs = torch.cuda.stream(s)
            current_stream = torch.cuda.current_stream()
            s.wait_stream(current_stream)

            with tcs:
                x = torch.relu(x)

            current_stream.wait_stream(s)
            x = torch.add(x, 1)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 9)

    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/118204
    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_context_manager2(self):
        def fn(x, s):
            x = torch.mul(x, 5)
            x = torch.add(x, 2)

            current_stream = torch.cuda.current_stream()
            s.wait_stream(current_stream)

            with torch.cuda.stream(s):
                x = torch.relu(x)

            current_stream.wait_stream(s)
            with torch.cuda.stream(current_stream):
                x = torch.relu(x)

            s2 = torch.cuda.Stream()
            s2.wait_stream(current_stream)
            with torch.cuda.stream(s2):
                x = torch.relu(x)

            current_stream.wait_stream(s2)
            x = torch.add(x, 1)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        s = torch.cuda.Stream()
        ref = fn(x, s)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x, s)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 18)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_method(self):
        def fn(x):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            new_stream = torch.cuda.Stream()
            cur_stream = torch.cuda.current_stream()
            new_stream.wait_stream(cur_stream)

            with torch.cuda.stream(new_stream):
                x = torch.sin(x)
                x = torch.add(x, 3)

            cur_stream.wait_stream(new_stream)

            x = torch.add(x, 4)
            is_idle = cur_stream.query()
            cur_stream.synchronize()

            with torch.cuda.stream(new_stream):
                x = torch.add(x, 5)
            new_stream.synchronize()

            is_equal = cur_stream == new_stream

            x = torch.relu(x)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 21)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_compared_with_constant(self):
        def fn(x):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            cur_stream = torch.cuda.current_stream()
            if cur_stream is not None:
                return x + 1
            return x - 1

        def fn2(x):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            cur_stream = torch.cuda.current_stream()
            if cur_stream != "const_str":
                return x + 1
            return x - 1

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
        res = opt_fn(x)
        res2 = opt_fn2(x)
        self.assertEqual(ref, res)
        self.assertEqual(ref, res2)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_stream_compared_with_stream(self):
        def fn(x, s0, s1):
            if s0 == s1:
                return x + 1
            else:
                return x - 1

        s0 = torch.cuda.Stream()
        s1 = torch.cuda.Stream()
        x = torch.randn(2, 2)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)

        ref0 = fn(x, s0, s1)
        res0 = opt_fn(x, s0, s1)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(ref0, res0)

        ref1 = fn(x, s1, s1)
        res1 = opt_fn(x, s1, s1)
        # We have a re-compilation because of changing inputs
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(ref1, res1)

        torch._dynamo.reset()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)

        ref1 = fn(x, s1, s1)
        res1 = opt_fn(x, s1, s1)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(ref1, res1)

        ref0 = fn(x, s0, s1)
        res0 = opt_fn(x, s0, s1)
        # We have a re-compilation because of changing inputs
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(ref0, res0)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_event_reconstruct(self):
        def fn(x):
            e = torch.cuda.Event()
            x = torch.mul(x, 5)
            x = torch.add(x, 2)
            return x, e

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertEqual(ref[0], res[0])
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 3)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_event_across_graph_break(self):
        def fn(x):
            e = torch.cuda.Event()
            e.record()
            x = torch.mul(x, 5)
            x = torch.add(x, 2)

            print("foo")

            torch.cuda.current_stream().wait_event(e)
            x = torch.add(x, 1)
            x = torch.cos(x)
            return x, e

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        res = opt_fn(x)
        self.assertEqual(ref[0], res[0])
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 9)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_event_created_outside_of_graph(self):
        user_stream = torch.cuda.Stream()
        event = torch.cuda.Event()
        foo = torch.empty((2, 2), device="cuda")

        def func(foo):
            event.wait()
            return foo + 1, event

        x = torch.randn((1024, 1024), device="cuda")
        cnts = torch._dynamo.testing.CompileCounter()

        def run_iters(fn, compile=False):
            if compile:
                fn = torch._dynamo.optimize(cnts)(fn)
            for _ in range(10):
                with torch.cuda.stream(user_stream):
                    torch.mm(x, x, out=foo)
                event.record()
                out = fn(foo)
            return out

        ref = run_iters(func, compile=False)
        res = run_iters(func, compile=True)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 3)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_event_method_create_stream_outside_of_compile(self):
        def fn(x, cur_stream, new_stream):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            x = torch.add(x, 3)

            event = cur_stream.record_event()
            is_idle = event.query()

            new_stream.wait_event(event)
            with torch.cuda.stream(new_stream):
                x = torch.add(x, 4)

            new_event = torch.cuda.Event()
            new_event.record(new_stream)

            new_event.wait(cur_stream)
            x = torch.add(x, 5)

            # use new event to sync
            new_event.synchronize()

            x = torch.relu(x)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        cur_stream = torch.cuda.current_stream()
        new_stream = torch.cuda.Stream()
        ref = fn(x, cur_stream, new_stream)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x, cur_stream, new_stream)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 19)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_event_method(self):
        def fn(x):
            x = torch.mul(x, 1)
            x = torch.add(x, 2)

            cur_stream = torch.cuda.current_stream()
            new_stream = torch.cuda.Stream()

            x = torch.add(x, 3)

            event = cur_stream.record_event()
            is_idle = event.query()

            new_stream.wait_event(event)
            with torch.cuda.stream(new_stream):
                x = torch.add(x, 4)

            new_event = torch.cuda.Event()
            new_event.record(new_stream)

            new_event.wait(cur_stream)
            x = torch.add(x, 5)

            # use new event to sync
            new_event.synchronize()

            x = torch.relu(x)
            x = torch.cos(x)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 19)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_device(self):
        def fn(x):
            with torch.cuda.device(x.device.index - 1):
                x = torch.sin(x + 1)
            return x

        x = torch.randn((2, 2), device="cuda")
        ref = fn(x)
        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_autograd_profiler_enabled(self):
        def fn(x):
            if torch.autograd._profiler_enabled():
                return x + 1
            else:
                return x - 1

        x = torch.randn((2, 2), requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)

        if torch.autograd._profiler_enabled():
            torch.autograd._disable_profiler()
        assert not torch.autograd._profiler_enabled()
        ref = fn(x)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

        with torch.autograd.profiler.profile():
            assert torch.autograd._profiler_enabled()
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast(self):
        if not torch.cuda.is_bf16_supported():
            raise unittest.SkipTest("requires bf16")

        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")
                d_float32 = torch.rand((8, 8), device="cuda")

                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    e_float16 = torch.mm(a_float32, b_float32)
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.type, "cuda")
        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.bfloat16)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cuda_amp_autocast(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")

                with torch.cuda.amp.autocast(dtype=torch.float64):
                    c_float64 = torch.mm(a_float32, b_float32)
                return c_float64

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.type, "cuda")
        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.float64)

    def test_is_autocast_cpu_enabled(self):
        def fn(a_float32, b_float32):
            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
                c_float16 = torch.mm(a_float32, b_float32)
                if torch.is_autocast_cpu_enabled():
                    c_float16 = c_float16 + 1
            return c_float16

        a = torch.rand((8, 8))
        b = torch.rand((8, 8))
        ref = fn(a, b)
        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(a, b)
        self.assertTrue(same(ref, res))

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION or TEST_WITH_ROCM,
        "Can't run fused SDPA on this platform",
    )
    def test_autocast_sdpa(self):
        class MyModule(torch.nn.Module):
            def forward(self, query, key, value):
                with torch.autocast("cpu"):
                    with torch.autocast("cuda", dtype=torch.float32):
                        out = F.scaled_dot_product_attention(
                            query, key, value, None, 0.0, True
                        )
                return out

        dtype = torch.float32
        seq_len_q = 1
        seq_len_k = 1
        head_dim = 8
        query = torch.ones(
            1, 8, seq_len_q, head_dim, device="cuda", dtype=dtype, requires_grad=True
        )
        key = torch.ones(
            1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
        )
        value = torch.ones(
            1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
        )

        module = MyModule()
        real = module(query, key, value)
        real_device = real.device
        real_dtype = real.dtype

        opt_mod = torch._dynamo.optimize("inductor")(module)
        compiled = opt_mod(query, key, value)

        self.assertEqual(compiled.device, real_device)
        self.assertEqual(compiled.dtype, real_dtype)

        self.assertEqual(compiled.device.type, "cuda")
        self.assertEqual(compiled.device.index, 0)
        self.assertEqual(compiled.dtype, torch.float32)

    def test_autocast_cpu(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")
                d_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    e_float16 = torch.mm(a_float32, b_float32)
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.type, "cpu")
        self.assertEqual(exported.dtype, torch.bfloat16)

    def test_autocast_cpu_graph_break(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")
                torch._dynamo.graph_break()
                d_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    e_float16 = torch.mm(a_float32, b_float32)
                    torch._dynamo.graph_break()
                    f_float16 = torch.mm(d_float32, e_float16)
                return f_float16

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        opt = torch._dynamo.optimize("eager")(module)
        res = opt(torch.tensor([0.5]))
        self.assertEqual(res.device, real_device)
        self.assertEqual(res.dtype, real_dtype)

        self.assertEqual(res.device.type, "cpu")
        self.assertEqual(res.dtype, torch.bfloat16)

    def test_autocast_cpu_graph_break_2(self):
        # Regression for: https://github.com/pytorch/pytorch/issues/93890
        def fn(x):
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                x = torch.mm(x, x)
                torch._dynamo.graph_break()
                x = torch.relu(x)
            return x

        x = torch.rand([4, 4])
        self.assertEqual(x.dtype, torch.float32)
        res = fn(x)
        opt_fn = torch._dynamo.optimize("eager")(fn)
        opt_res = opt_fn(x)
        self.assertTrue(torch.allclose(res, opt_res))
        self.assertEqual(res.dtype, torch.bfloat16)
        self.assertEqual(opt_res.dtype, torch.bfloat16)

    def test_autocast_cpu_graph_break_inner_fn(self):
        class MyModule(torch.nn.Module):
            @staticmethod
            def mm_breaks(x, y):
                torch._dynamo.graph_break()
                return torch.mm(x, y)

            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    torch._dynamo.graph_break()
                    with torch.autocast(
                        device_type="cpu", dtype=torch.bfloat16, enabled=False
                    ):
                        torch._dynamo.graph_break()
                        g_float32 = torch.mm(a_float32, b_float32)
                        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                            # Check the nested autocast when calling a
                            # non-inlineable function that graph breaks
                            torch._dynamo.graph_break()
                            f_float16_1 = self.mm_breaks(a_float32, b_float32)
                    # We should exit the inner autocast and restore the outer one
                    # correctly, even after graph breaks
                    f_float16 = self.mm_breaks(a_float32, b_float32)
                    assert f_float16.dtype == f_float16_1.dtype
                return f_float16, g_float32

        module = MyModule()
        real_16, real_32 = module(torch.tensor([0.5]))
        real_device_16 = real_16.device
        real_dtype_16 = real_16.dtype
        real_device_32 = real_32.device
        real_dtype_32 = real_32.dtype

        graph = torch._dynamo.optimize("eager")(module)
        out_16, out_32 = graph(torch.tensor([0.5]))
        self.assertEqual(out_16.device, real_device_16)
        self.assertEqual(out_16.dtype, real_dtype_16)
        self.assertEqual(out_32.device, real_device_32)
        self.assertEqual(out_32.dtype, real_dtype_32)

        self.assertEqual(out_16.device.type, "cpu")
        self.assertEqual(out_16.dtype, torch.bfloat16)
        self.assertEqual(out_32.device.type, "cpu")
        self.assertEqual(out_32.dtype, torch.float32)

    def test_autocast_graph_break_method(self):
        class MyModule(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.bias = bias

            def mm_not_break(self, x, y):
                return torch.mm(x, y) + self.bias

            def mm_breaks(self, x, y):
                torch._dynamo.graph_break()
                return torch.mm(x, y) + self.bias

            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cpu")
                b_float32 = torch.rand((8, 8), device="cpu")

                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                    with torch.autocast(
                        device_type="cpu", dtype=torch.bfloat16, enabled=False
                    ):
                        g_float32 = torch.mm(a_float32, b_float32)
                    f_float16 = self.mm_breaks(a_float32, b_float32)

                    assert (
                        f_float16[0][0] == self.mm_not_break(a_float32, b_float32)[0][0]
                    )
                return f_float16, g_float32

        module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16))

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # Autocast doesn't work on addition, so we need the bias to be `bfloat16`
            res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand(
                (8, 8), device="cpu", dtype=torch.bfloat16
            )
            self.assertEqual(res.dtype, torch.float32)

        real_16, real_32 = module(torch.tensor([0.5]))
        real_device_16 = real_16.device
        real_dtype_16 = real_16.dtype
        real_device_32 = real_32.device
        real_dtype_32 = real_32.dtype

        graph = torch._dynamo.optimize("eager")(module)
        out_16, out_32 = graph(torch.tensor([0.5]))
        self.assertEqual(out_16.device, real_device_16)
        self.assertEqual(out_16.dtype, real_dtype_16)
        self.assertEqual(out_32.device, real_device_32)
        self.assertEqual(out_32.dtype, real_dtype_32)

        self.assertEqual(out_16.device.type, "cpu")
        self.assertEqual(out_16.dtype, torch.bfloat16)
        self.assertEqual(out_32.device.type, "cpu")
        self.assertEqual(out_32.dtype, torch.float32)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_float64(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")
                d_float32 = torch.rand((8, 8), device="cuda")

                with torch.autocast(device_type="cuda", dtype=torch.float64):
                    e_float64 = torch.mm(a_float32, b_float32)
                    f_float64 = torch.mm(d_float32, e_float64)
                return f_float64

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.float64)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_device(self):
        class MyModule(torch.nn.Module):
            def forward(self, x):
                a_float32 = torch.rand((8, 8), device="cuda")
                b_float32 = torch.rand((8, 8), device="cuda")
                d_float32 = torch.rand((8, 8), device="cuda")

                with torch.autocast("cuda"):
                    e_float64 = torch.mm(a_float32, b_float32)
                    f_float64 = torch.mm(d_float32, e_float64)
                return f_float64

        module = MyModule()
        real = module(torch.tensor([0.5]))
        real_device = real.device
        real_dtype = real.dtype

        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
        exported = graph(torch.tensor([0.5]))
        self.assertEqual(exported.device, real_device)
        self.assertEqual(exported.dtype, real_dtype)

        self.assertEqual(exported.device.index, 0)
        self.assertEqual(exported.dtype, torch.float16)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_arguments_binding(self):
        def f1(x):
            with torch.cuda.amp.autocast(False):
                x = torch.sin(x + 1)
            return x

        def f2(x):
            with torch.cpu.amp.autocast(False):
                x = torch.cos(x + 1)
            return x

        x = torch.rand([2, 3])
        ref1 = f1(x)
        ref2 = f2(x)
        opt_f1 = torch.compile(backend="eager")(f1)
        opt_f2 = torch.compile(backend="eager")(f2)
        res1 = opt_f1(x)
        res2 = opt_f2(x)
        self.assertTrue(same(ref1, res1))
        self.assertTrue(same(ref2, res2))

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_autocast_decorator(self):
        def autocast_func(orig_func):
            @torch.amp.autocast(device_type="cuda", dtype=torch.float16)
            def new_fwd(*args, **kwargs):
                return orig_func(*args, **kwargs)

            return new_fwd

        def autocast_func_cuda(orig_func):
            @torch.cuda.amp.autocast(dtype=torch.float16)
            def new_fwd(*args, **kwargs):
                return orig_func(*args, **kwargs)

            return new_fwd

        def autocast_func_cpu(orig_func):
            @torch.cpu.amp.autocast(dtype=torch.float16)
            def new_fwd(*args, **kwargs):
                return orig_func(*args, **kwargs)

            return new_fwd

        def mm(a, b):
            return torch.mm(a, b)

        mm_float16 = autocast_func(mm)
        mm_float16_cuda = autocast_func_cuda(mm)
        mm_float16_cpu = autocast_func_cpu(mm)

        def fn(a, b):
            return mm_float16(a, b), mm_float16_cuda(a, b), mm_float16_cpu(a, b)

        a_float32 = torch.rand((8, 8), device="cuda")
        b_float32 = torch.rand((8, 8), device="cuda")

        ref = fn(a_float32, b_float32)
        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
        res = opt_fn(a_float32, b_float32)
        self.assertTrue(same(ref, res))
        self.assertTrue(res[0].dtype == torch.float16)
        self.assertTrue(res[1].dtype == torch.float16)

    def test_generic_ctx_manager_with_graph_break(self):
        def fn(x):
            with CustomizedCtxManagerWithGraphBreak(False):
                # body runs on eager
                y = x * 2
                z = y.sin() + 3
            return z

        x = torch.randn(2, 3)
        opt_fn = torch.compile(backend="eager", fullgraph=False)(fn)
        self.assertEqual(fn(x), opt_fn(x))

    def test_return_context_manager(self):
        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            cm = CustomizedCtxManager(False)
            with cm:
                pass
            return cm

        x = torch.randn(2, 3)
        cm = f(x)
        self.assertFalse(cm.mode)

    def test_return_context_manager_with_graph_break(self):
        @torch.compile(backend="eager", fullgraph=False)
        def f(x):
            cm = CustomizedCtxManager(False)
            torch._dynamo.graph_break()
            with cm:
                pass
            return cm

        x = torch.randn(2, 3)
        cm = f(x)
        self.assertFalse(cm.mode)

    def test_generic_context_manager(self):
        def fn(x):
            with CustomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                x = torch.relu(x)
            return x - 1

        x = torch.rand(2, 3)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(backend=cnts, fullgraph=True)(fn)

        with torch.no_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 1)
            self.assertEqual(cnts.op_count, 6)

        with torch.enable_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 2)
            self.assertEqual(cnts.op_count, 12)

    def test_nested_generic_context_manager(self):
        def fn(x):
            with CustomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                with CustomizedCtxManager(False):
                    if torch.is_grad_enabled():
                        x = x - 3
                    x = x * 1.5
                x = torch.relu(x)
            return x - 1

        x = torch.rand(2, 3)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(backend=cnts, fullgraph=True)(fn)

        with torch.no_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 1)
            self.assertEqual(cnts.op_count, 9)

        with torch.enable_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 2)
            self.assertEqual(cnts.op_count, 18)

    def test_generic_context_manager_with_graph_break(self):
        def fn(x):
            with CustomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                torch._dynamo.graph_break()
                x = torch.relu(x)
            return x - 1

        x = torch.rand(2, 3)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)

        with torch.no_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 2)
            self.assertEqual(cnts.op_count, 2)

        with torch.enable_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 4)
            self.assertEqual(cnts.op_count, 4)

    def test_nested_generic_context_manager_with_graph_break(self):
        def fn(x):
            with CustomizedCtxManager(True):
                x = x + 1
                if torch.is_grad_enabled():
                    x = x * 2
                with CustomizedCtxManager(False):
                    if torch.is_grad_enabled():
                        x = x - 3
                    torch._dynamo.graph_break()
                    x = x * 1.5
                x = torch.relu(x)
            return x - 1

        x = torch.rand(2, 3)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)

        with torch.no_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 4)
            self.assertEqual(cnts.op_count, 4)

        torch._dynamo.reset()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)

        with torch.enable_grad():
            ref = fn(x)
            res = opt_fn(x)
            self.assertTrue(same(ref, res))
            self.assertEqual(cnts.frame_count, 4)
            self.assertEqual(cnts.op_count, 4)

    def test_graph_break_inlining_grad(self):
        def gn(z):
            with torch.no_grad():
                torch._dynamo.graph_break()
                return torch.sin(z)

        def fn(x, y, z):
            a = torch.mm(x, y)
            z = gn(z)
            return a

        torch._dynamo.reset()
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        z = torch.randn(4)
        opt_fn(x, y, z).sum().backward()

        self.assertEqual(cnts.frame_count, 2)

    def _graph_break_inlining_autocast_test_helper(self, device):
        def gn(x, y):
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                z = torch.mm(x, y)
                torch._dynamo.graph_break()
                return torch.sin(z)

        def fn(x, y):
            z = torch.mm(x, y)
            z = z + gn(x, y)
            return z

        x = torch.rand(3, 3).to(device)
        y = torch.rand(3, 3).to(device)
        opt_fn = torch.compile(backend="eager")(fn)
        ref = fn(x, y)
        res = opt_fn(x, y)
        self.assertEqual(ref, res)

    def test_graph_break_inlining_autocast(self):
        for device in ["cuda", "cpu"]:
            if device == "cuda" and not (
                torch.cuda.is_available() and torch.cuda.is_bf16_supported()
            ):
                continue
            self._graph_break_inlining_autocast_test_helper(device)

    def test_disable_saved_tensors_hooks(self):
        def fn(z):
            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
            def f(x, y):
                return x + y

            x, y = torch.ones(
                1,
            ), torch.zeros(
                1,
            )
            return f(x, y)

        eager = EagerAndRecordGraphs()
        torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self):
        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None

        x: "f32[1]" = torch.ones(1)

        y: "f32[1]" = torch.zeros(1)

        add: "f32[1]" = x + y; x = y = None

        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (add,)
""",  # NOQA: B950
        )

    def test_disable_saved_tensors_hooks_prev_disabled(self):
        def fn(z):
            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
            def f(x, y):
                return x + y

            x, y = torch.ones(
                1,
            ), torch.zeros(
                1,
            )
            return f(x, y)

        eager = EagerAndRecordGraphs()
        with torch.autograd.graph.disable_saved_tensors_hooks(
            "Previously disabled message"
        ):
            torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self):
        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None

        x: "f32[1]" = torch.ones(1)

        y: "f32[1]" = torch.zeros(1)

        add: "f32[1]" = x + y; x = y = None

        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message'); _saved_tensors_hooks_disable_1 = None
        return (add,)
""",  # NOQA: B950
        )

    def test_disable_saved_tensors_hooks_prev_disabled_nested(self):
        def fn(z):
            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
            def f(x, y):
                @torch.autograd.graph.disable_saved_tensors_hooks(
                    "This is not supported inner"
                )
                def inner_fn(x, y):
                    return x + y

                return inner_fn(x, y) + x

            x, y = torch.ones(
                1,
            ), torch.zeros(
                1,
            )
            return f(x, y)

        eager = EagerAndRecordGraphs()
        with torch.autograd.graph.disable_saved_tensors_hooks(
            "Previously disabled message"
        ):
            torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))

        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self):
        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None

        x: "f32[1]" = torch.ones(1)

        y: "f32[1]" = torch.zeros(1)

        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner'); _saved_tensors_hooks_disable_1 = None

        add: "f32[1]" = x + y; y = None

        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable_2 = None

        add_1: "f32[1]" = add + x; add = x = None

        _saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message'); _saved_tensors_hooks_disable_3 = None
        return (add_1,)
""",  # NOQA: B950
        )

    def test_disable_saved_tensors_hooks_graph_break(self):
        def fn(x):
            with torch.autograd.graph.disable_saved_tensors_hooks(
                "This is not supported"
            ):
                y = x + 1
                torch._dynamo.graph_break()
                return y * 2

        eager = EagerAndRecordGraphs()
        torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))

        def check_graph(actual, expected):
            self.assertExpectedInline(actual, expected)

        graph = eager.graphs[0]
        actual = normalize_gm(graph.print_readable(False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[]"):
        l_x_ = L_x_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None

        y: "f32[]" = l_x_ + 1; l_x_ = None

        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (y,)
""",  # NOQA: B950
        )

        graph = eager.graphs[1]
        actual = normalize_gm(graph.print_readable(False))
        self.assertExpectedInline(
            actual,
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_y_: "f32[]"):
        l_y_ = L_y_

        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported'); _saved_tensors_hooks_disable = None

        mul: "f32[]" = l_y_ * 2; l_y_ = None

        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
        return (mul,)
""",  # NOQA: B950
        )

    def test_context_wrapping_grad_mode_decorator(self):
        ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]
        for call in [True, False]:
            for i in range(2):
                torch._dynamo.reset()

                ctx_wrapper, mode = ctx_wrappers[i]
                ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]

                def fn(x):
                    def inner_func(x):
                        return x.sin()

                    with ctx_wrapper_inverse():
                        if call:
                            inner_func = ctx_wrapper()(inner_func)
                        else:
                            inner_func = ctx_wrapper(inner_func)

                        # Calling no_grad or enable_grad should not mutate global state
                        assert torch.is_grad_enabled() == mode_inverse

                    with ctx_wrapper_inverse():
                        return inner_func(x)

                x = torch.zeros(10, requires_grad=True)
                opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
                self.assertEqual(fn(x), opt_fn(x))
                self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)

    def test_context_wrapping_grad_mode_nested_function_decorator(self):
        ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]

        for call in [True, False]:
            for i in range(2):
                torch._dynamo.reset()

                ctx_wrapper, mode = ctx_wrappers[i]
                ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]

                def fn(x):
                    with ctx_wrapper_inverse():
                        if call:

                            @ctx_wrapper()
                            def inner_func(x):
                                return x.sin()

                        else:

                            @ctx_wrapper
                            def inner_func(x):
                                return x.sin()

                        # Calling no_grad or enable_grad should not mutate global state
                        assert torch.is_grad_enabled() == mode_inverse

                    with ctx_wrapper_inverse():
                        return inner_func(x)

                x = torch.zeros(10, requires_grad=True)
                opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
                self.assertEqual(fn(x), opt_fn(x))
                self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)

    def test_context_wrapping_set_grad_enabled_nested_function(self):
        modes = [True, False]
        for decorator in [True, False]:
            for i in range(2):
                torch._dynamo.reset()

                mode = modes[i]
                mode_inverse = modes[(i + 1) % 2]

                def fn(x):
                    with torch.set_grad_enabled(mode_inverse):
                        if decorator:

                            @torch.set_grad_enabled(mode)
                            def inner_func(x):
                                return x.sin()

                        else:

                            def inner_func(x):
                                return x.sin()

                            inner_func = torch.set_grad_enabled(mode)(inner_func)

                        # Consuming set_grad_enabled by calling it on a function
                        # should not mutate global state
                        assert torch.is_grad_enabled() == mode_inverse

                    with torch.set_grad_enabled(mode_inverse):
                        return inner_func(x)

                x = torch.zeros(10, requires_grad=True)
                opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
                self.assertEqual(fn(x), opt_fn(x))
                self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)

    def test_inactive_context_graph_break_local(self):
        def fn(x):
            x = x + 1
            ctx = torch.set_grad_enabled(True)
            torch._dynamo.graph_break()
            with ctx:
                x = x + 1
            return x

        x = torch.zeros(10, requires_grad=False)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
        self.assertEqual(cnts.frame_count, 2)

    def test_inactive_context_graph_break_local_nullctx(self):
        import contextlib

        # test with a context manager that results in None target_values
        def fn(x):
            x = x + 1
            ctx = contextlib.nullcontext()
            torch._dynamo.graph_break()
            with ctx:
                x = x + 1
            return x

        x = torch.zeros(10, requires_grad=False)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
        self.assertEqual(cnts.frame_count, 2)

    def test_inactive_context_graph_break_local_nullctx2(self):
        import contextlib

        # test with nullcontext where the graph break happens
        # in an inlined function that returns something
        def gn():
            torch._dynamo.graph_break()
            return [0, 1, 2]

        def fn(x):
            x = x + 1
            ctx = contextlib.nullcontext()
            lst = gn()
            with ctx:
                x = x + lst[1]
            return x

        x = torch.zeros(10, requires_grad=False)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
        self.assertEqual(cnts.frame_count, 2)

    def test_inactive_context_graph_break_stack(self):
        def gn(ctx):
            torch._dynamo.graph_break()
            return ctx

        def fn(x):
            x = x + 1
            ctx = gn(torch.set_grad_enabled(True))
            # we expect a graph break on the next line as well
            with ctx:
                x = x + 1
            return x

        x = torch.zeros(10, requires_grad=False)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)

    def test_inactive_context_graph_break_stack2(self):
        def gn(x, ctx, y, z, dummy):
            with ctx:
                return x * y * z

        def fn(x):
            x = x + 1
            x = gn(x, torch.set_grad_enabled(True), 2, 3, torch._dynamo.graph_break())
            return x

        x = torch.zeros(10, requires_grad=False)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(x), opt_fn(x))
        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
        self.assertEqual(cnts.frame_count, 2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()