# Owner(s): ["module: dynamo"]
import copy
import functools
import math
import unittest  # noqa: F811
from importlib import import_module

import torch
import torch._dynamo.config
import torch._dynamo.test_case
import torch._functorch.config
import torch.distributed as dist
import torch.nn as nn
import torch.utils.checkpoint
from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.testing import CompileCounterWithBackend
from torch._higher_order_ops.wrap import tag_activation_checkpoint
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
    SM90OrLater,
)
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils.checkpoint import (
    checkpoint,
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
requires_distributed = functools.partial(
    unittest.skipIf, not dist.is_available(), "requires distributed"
)


def checkpoint_wrapper(fn):
    def inner(*args):
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

    return inner


def count_ops(
    gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
):
    def match_rng_op(node, op):
        if isinstance(node.target, torch._ops.HigherOrderOperator):
            if node.name == "run_and_save_rng_state":
                return node.args[0] == op
            elif node.name == "run_with_rng_state":
                return node.args[1] == op
        return False

    # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
    if op is not None:
        assert not isinstance(op, list)
        ops = [op]
    if freq is not None:
        freqs = [freq]
    if freq_ge is not None:
        freqs_ge = [freq_ge]
    if freqs:
        for op, freq in zip(ops, freqs):
            actual_count = 0
            for node in gm.graph.nodes:
                if match_rng_op(node, op) or node.target == op:
                    actual_count += 1
            err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
            assert actual_count == freq, err_msg
    else:
        assert freqs_ge is not None
        for op, freq_ge in zip(ops, freqs_ge):
            actual_count = 0
            for node in gm.graph.nodes:
                if match_rng_op(node, op) or node.target == op:
                    actual_count += 1
            assert (
                actual_count >= freq_ge
            ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
    return gm


class _InvalidContext:
    def __init__(self) -> None:
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


def _invalid_context_gen():
    return _InvalidContext(), _InvalidContext()


def find_first_node(gm, func):
    for node in gm.graph.nodes:
        if node.target is func:
            return node
    return None


def op_count(gm):
    result = 0
    for node in gm.graph.nodes:
        if "call" in node.op:
            result += 1
    return result


def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
    def _custom_policy(ctx, func, *args, **kwargs):
        if no_recompute_list is not None and func in no_recompute_list:
            return CheckpointPolicy.MUST_SAVE
        if must_recompute_list is not None and func in must_recompute_list:
            return CheckpointPolicy.MUST_RECOMPUTE
        else:
            return CheckpointPolicy.PREFER_RECOMPUTE

    return _custom_policy


class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
        cloned_args = []
        for arg in args:
            cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))

        torch.manual_seed(0)
        expected = fn(*args)
        expected.sum().backward()

        torch.manual_seed(0)
        result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args)
        result.sum().backward()

        if not skip_check:
            self.assertEqual(
                result,
                expected,
                msg="Output mismatch between torch.compile and eager versions",
            )
            for arg, cloned_arg in zip(args, cloned_args):
                self.assertEqual(
                    arg.grad,
                    cloned_arg.grad,
                    msg="Gradient mismatch between torch.compile and eager versions",
                )

    def _compare_orig_and_checkpointed_fns(
        self, orig_fn, checkpointed_fn, *args, fullgraph=True
    ):
        # The original version and the checkpointed version of the same function
        # should produce the same outputs and the same gradients under torch.compile.

        # Run original version
        cloned_args_orig_fn = []
        for arg in args:
            cloned_args_orig_fn.append(
                arg.clone().detach().requires_grad_(arg.requires_grad)
            )
        torch.manual_seed(0)
        compiled_orig_fn = torch.compile(
            orig_fn, fullgraph=fullgraph, backend="inductor"
        )
        result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
        result_orig_fn.sum().backward()

        # Run checkpointed version
        cloned_args_checkpointed_fn = []
        for arg in args:
            cloned_args_checkpointed_fn.append(
                arg.clone().detach().requires_grad_(arg.requires_grad)
            )
        torch.manual_seed(0)
        compiled_checkpointed_fn = torch.compile(
            checkpointed_fn, fullgraph=fullgraph, backend="inductor"
        )
        result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn)
        result_checkpointed_fn.sum().backward()

        # Check that outputs and gradients are equal
        self.assertEqual(
            result_orig_fn,
            result_checkpointed_fn,
            msg="Output mismatch between the original version and the checkpointed version of the same function",
        )
        for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
            cloned_args_orig_fn, cloned_args_checkpointed_fn
        ):
            self.assertEqual(
                cloned_arg_orig_fn.grad,
                cloned_arg_checkpointed_fn.grad,
                msg="Gradient mismatch between the original version and the checkpointed version of the same function",
            )

    @requires_cuda
    def test_tags_function(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda
    def test_tags_function_via_global_checkpoint(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            # This goes through VariableBuilder
            return checkpoint(gn, torch.sin(x), y, use_reentrant=True)

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda
    def test_tags_function_with_kwargs(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
            )

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=3, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda
    def test_tags_sequential_layers(self):
        def gn(x):
            x = x.cos()
            for _ in range(3):
                x = torch.mm(x, x)
            x = x.cos()
            return x

        def fn(x):
            x = torch.utils.checkpoint.checkpoint(gn, x)
            x = torch.utils.checkpoint.checkpoint(gn, x)
            return x

        x = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=6, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops,
            freqs=[2, 18],
            ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    @requires_cuda
    def test_tags_multiple_checkpoints(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y))

        def fn(x, y):
            x = torch.sin(x)
            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(z)
            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            return z

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        bw_compiler = functools.partial(
            count_ops, freq=6, op=torch.ops.aten.mm.default
        )  # mm recomputed in the bwd
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x, y)

    @requires_cuda
    def test_tags_module(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return torch.sigmoid(self.linear(x))

        mod = MockModule().cuda()

        def fn(x):
            return torch.utils.checkpoint.checkpoint(
                mod, torch.sin(x), use_reentrant=True
            )

        x = torch.randn(10, 10, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
        )
        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        self._validate(fn, backend, x)

    @requires_cuda
    def test_tags_decomps(self):
        # Ensures that tags are passed on through decompositions as well
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x):
                return torch.nn.functional.gelu(self.linear(x))

        mod = MockModule().cuda()

        def fn(x):
            return torch.utils.checkpoint.checkpoint(
                mod, torch.sin(x), use_reentrant=True
            )

        x = torch.randn(10, 10, device="cuda", requires_grad=True)

        fw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.erf.default
        )
        bw_compiler = functools.partial(
            count_ops, freq=1, op=torch.ops.aten.erf.default
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            decompositions=lambda: import_module(
                "torch._inductor.compile_fx"
            ).select_decomp_table(),
        )
        self._validate(fn, backend, x)

    @requires_cuda
    @torch._inductor.config.patch(fallback_random=True)
    def test_tags_recomputed_rand(self):
        def gn(x, y):
            return torch.sigmoid(torch.rand_like(x) * y) * x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            return z

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        # bw_compiler = functools.partial(
        #     count_ops, freq=6, op=torch.ops.aten.mm.default
        # )  # mm recomputed in the bwd
        # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        backend = "inductor"
        self._validate(fn, backend, x, y)

    @requires_cuda
    @torch._inductor.config.patch(fallback_random=True)
    def test_tags_rand(self):
        def gn(x, y):
            x = torch.mm(x, y)
            x = torch.mm(x, y)
            return x

        def fn(x, y):
            x = torch.sin(x)
            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            x = torch.sin(x)
            # x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
            return x

        x = torch.randn(4, 4, device="cuda", requires_grad=True)
        y = torch.randn(4, 4, device="cuda", requires_grad=True)

        # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
        # bw_compiler = functools.partial(
        #     count_ops, freq=6, op=torch.ops.aten.mm.default
        # )  # mm recomputed in the bwd
        # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
        # backend = "aot_eager"
        backend = "inductor"
        self._validate(fn, backend, x, y)

    @requires_cuda
    @torch._inductor.config.patch(fallback_random=True)
    def test_tags_dropout(self):
        # Figure out a way to test the number of inductor_random calls
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)
                self.dropout = torch.nn.Dropout(0.2)

            def forward(self, x):
                return self.dropout(self.linear(x))

        mod = MockModule().cuda()

        def fn(x):
            return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)

        x = torch.randn(10, 10, device="cuda", requires_grad=True)
        backend = "inductor"
        # rand decomps do not have the same numerical results as eager
        self._validate(fn, backend, x, skip_check=True)

    @requires_cuda
    def test_fallback(self):
        def gn(x, y):
            torch._dynamo.graph_break()
            a = torch.sigmoid(torch.matmul(x, y))
            torch._dynamo.graph_break()
            return torch.cos(a)

        def fn(x, y):
            return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        args = (x, y)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        # One graph for torch.sin on the input, and another for torch.cos.
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)
        self.assertEqual(len(cnt.graphs), 2)

    @requires_cuda
    def test_kwargs(self):
        def gn(x, y, z=None):
            a = torch.matmul(x, y)
            if z is not None:
                return torch.matmul(a, z)
            return a

        def fn(x, y, z):
            return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        z = torch.randn(4, 4, requires_grad=True)
        args = (x, y, z)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)

        expected = fn(*args)
        result = torch.compile(fn, backend=cnt)(*args)

        self.assertEqual(result, expected)

        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(len(cnt.graphs), 1)

        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        # one for checkpoint, and 3 for x, y, z
        self.assertEqual(len(wrap_node.args), 4)

        body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
        self.assertEqual(op_count(body_function), 2)

    @requires_cuda
    def test_symints_location(self):
        def gn(x, y):
            return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)

        backend = "aot_eager"
        cnt = CompileCounterWithBackend(backend)
        opt_fn = torch.compile(fn, backend=cnt)

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)
        args = (x, y)
        expected = fn(*args)
        result = opt_fn(*args)

        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)
        args = (x, y)
        expected = fn(*args)
        result = opt_fn(*args)

        self.assertEqual(result.shape, expected.shape)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(len(cnt.graphs), 2)
        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
        self.assertEqual(len(wrap_node.args), 3)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_recompute(self):
        def context_fn_must_recompute_mm():
            must_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(
                    must_recompute_list=must_recompute_list,
                ),
            )

        def context_fn_no_recompute_mm():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(
                    no_recompute_list=no_recompute_list,
                ),
            )

        def _test(context_fn, bw_compiler):
            def gn(x):
                return torch.sigmoid(torch.matmul(x, x))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
                    gn,
                    x,
                    use_reentrant=False,
                    context_fn=context_fn,
                )

            x = torch.randn(4, 4, requires_grad=True)

            fw_compiler = functools.partial(
                count_ops,
                freq=1,
                op=torch.ops.aten.mm.default,
            )

            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
            )
            self._validate(fn, backend, x)

        _test(
            context_fn=context_fn_must_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=3,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3
                op=torch.ops.aten.mm.default,
            ),
        )
        _test(
            context_fn=context_fn_no_recompute_mm,
            bw_compiler=functools.partial(
                count_ops,
                freq=2,  # 2 bwd mm ops per fwd matmul
                op=torch.ops.aten.mm.default,
            ),
        )

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_must_not_recompute_gemm(self):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device="cuda")
        y = torch.randn(4, 4, requires_grad=True, device="cuda")

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 6 here
            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
            # if we didn't enable selective checkpointing.
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_tensor_subclass(self):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        rand_tensor = torch.randn(4, 4, requires_grad=True, device="cuda")

        # tensor subclasses as inputs
        x = TwoTensor(rand_tensor, rand_tensor.clone())
        y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())

        fw_compiler = functools.partial(
            count_ops,
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 12 here
            # (4 matmul recompute and 2 mm ops per fwd matmul, so 4 + 2 * 4 = 12)
            # if we didn't enable selective checkpointing.
            freq=8,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_custom_rule(self):
        def _get_custom_policy(meta):
            no_recompute_list = [
                torch.ops.aten.mm.default,
            ]

            def _custom_policy(mode, func, *args, **kwargs):
                mm_count_key = f"{mode}_mm_count"
                if mm_count_key not in meta:
                    meta[mm_count_key] = 0
                if func == torch.ops.aten.mm.default:
                    meta[mm_count_key] += 1
                # Saves output of all compute ops, except second mm
                # (i.e. we will hint the partitioner to recompute second mm in backward pass)
                return func in no_recompute_list and not (
                    func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
                )

            return _custom_policy

        def selective_checkpointing_context_fn():
            meta = {}
            return create_selective_checkpoint_contexts(_get_custom_policy(meta))

        def gn(x, y):
            return torch.sigmoid(
                torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
            )

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device="cuda")
        y = torch.randn(4, 4, requires_grad=True, device="cuda")

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # Q: How do we come to this number 4?
            # A: We have 2 matmuls in the forward pass, each matmul contributes 2 `mm` ops in the backward pass,
            # so we have at least 4 `mm` ops in backward pass. It's "at least" because whether second matmul in
            # the forward pass is recomputed in the backward pass is up to the partitioner to decide.
            freq_ge=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_partial_ctx_fn(self):
        def selective_checkpointing_context_fn(no_recompute_list):
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=functools.partial(
                    selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
                ),
            )

        x = torch.randn(4, 4, requires_grad=True, device="cuda")
        y = torch.randn(4, 4, requires_grad=True, device="cuda")

        fw_compiler = functools.partial(
            count_ops,
            freq=2,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            # We would've expected 6 here
            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
            # if we didn't enable selective checkpointing.
            freq=4,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_outplace_op(self):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
                torch.ops.aten.sigmoid.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list),
            )

        def gn(x, y):
            return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device="cuda")
        y = torch.randn(4, 4, requires_grad=True, device="cuda")

        fw_compiler = functools.partial(
            count_ops,
            freqs=[2, 1],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[4, 0],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    @unittest.skip(
        "In-place op support in selective checkpointing + torch.compile "
        "requires TorchDispatchMode + torch.compile work to complete"
    )
    def test_compile_selective_checkpoint_inplace_op(self):
        def selective_checkpointing_context_fn():
            no_recompute_list = [
                torch.ops.aten.mm.default,
                torch.ops.aten.sigmoid.default,
            ]
            return create_selective_checkpoint_contexts(
                _get_custom_policy(no_recompute_list=no_recompute_list)
            )

        def gn(x, y):
            return torch.sigmoid(
                torch.selu_(torch.matmul(torch.matmul(x, y), y))
            ).relu_()

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=selective_checkpointing_context_fn,
            )

        x = torch.randn(4, 4, requires_grad=True, device="cuda")
        y = torch.randn(4, 4, requires_grad=True, device="cuda")

        fw_compiler = functools.partial(
            count_ops,
            freqs=[2, 1],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[4, 0],
            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        self._validate(fn, backend, x, y)
        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)

    @requires_cuda
    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_random_op(self):
        for preserve_rng_state in [True, False]:

            def selective_checkpointing_context_fn():
                no_recompute_list = [
                    torch.ops.aten.sigmoid.default,
                ]
                return create_selective_checkpoint_contexts(
                    _get_custom_policy(no_recompute_list=no_recompute_list)
                )

            def gn(x):
                return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))

            def fn(x):
                return torch.utils.checkpoint.checkpoint(
                    gn,
                    x,
                    use_reentrant=False,
                    # Regardless of whether `preserve_rng_state` is True or False,
                    # we will always preserve RNG state when using `torch.compile`.
                    preserve_rng_state=preserve_rng_state,
                    context_fn=selective_checkpointing_context_fn,
                )

            x = torch.randn(4, 4, requires_grad=True, device="cuda")

            fw_compiler = functools.partial(
                count_ops,
                freqs=[2, 1],
                ops=[
                    torch.ops.aten.sigmoid.default,
                    torch.ops.aten.native_dropout.default,
                ],
            )
            bw_compiler = functools.partial(
                count_ops,
                # NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
                freqs=[0, 1],
                ops=[
                    torch.ops.aten.sigmoid.default,
                    torch.ops.aten.native_dropout.default,
                ],
            )
            backend = aot_autograd(
                fw_compiler=fw_compiler,
                bw_compiler=bw_compiler,
                partition_fn=min_cut_rematerialization_partition,
            )

            # NOTE: when `preserve_rng_state` is False, gradients will mismatch between torch.compile and eager,
            # because the eager version doesn't preserve RNG state while torch.compile still does.
            # Hence when `preserve_rng_state` is False, we skip the output and gradient comparison
            # between torch.compile and eager.
            self._validate(fn, backend, x, skip_check=not preserve_rng_state)
            self._compare_orig_and_checkpointed_fns(gn, fn, x)

    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
    def test_compile_selective_checkpoint_invalid_context(self):
        def gn(x, y):
            return torch.sigmoid(torch.matmul(x, y)) * y

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(
                gn,
                x,
                y,
                use_reentrant=False,
                context_fn=_invalid_context_gen,
            )

        x = torch.randn(4, 4, requires_grad=True)
        y = torch.randn(4, 4, requires_grad=True)

        fw_compiler = functools.partial(
            count_ops,
            freq=1,
            op=torch.ops.aten.mm.default,
        )
        bw_compiler = functools.partial(
            count_ops,
            freq_ge=2,
            op=torch.ops.aten.mm.default,
        )
        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        with self.assertRaisesRegex(
            Exception, "must generate a tuple of two `TorchDispatchMode`s"
        ):
            self._validate(fn, backend, x, y)

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_compile_selective_checkpoint_parametrization(self):
        def sac_policy():
            def _recomp_policy():
                def _custom_policy(ctx, func, *args, **kwargs):
                    to_recompute = func in {
                        torch.ops.aten.mul.Tensor,
                        torch.ops.aten.sigmoid.default,
                    }
                    return (
                        CheckpointPolicy.MUST_RECOMPUTE
                        if to_recompute
                        else CheckpointPolicy.MUST_SAVE
                    )

                return _custom_policy

            return create_selective_checkpoint_contexts(_recomp_policy())

        class Parametrization(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def parametrization(self, x):
                return torch.sigmoid(torch.mul(x, x))

            def forward(self, x):
                return checkpoint(
                    self.parametrization, x, use_reentrant=False, context_fn=sac_policy
                )

        def apply_parametrization(model):
            modules = list(model.modules())

            for mod in modules:
                params_dict = dict(mod.named_parameters(recurse=False))
                for p_name, p in params_dict.items():
                    mod.register_parameter(p_name, nn.Parameter(p))
                    nn.utils.parametrize.register_parametrization(
                        mod, p_name, Parametrization(), unsafe=True
                    )

            return model

        class MLPModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(5)
                self.net1 = nn.Linear(16, 16, bias=False)

            def forward(self, x):
                return self.net1(x)

            def reset_parameters(self):
                self.net1.reset_parameters()

        fw_compiler = functools.partial(
            count_ops,
            freqs=[1, 1],
            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
        )
        bw_compiler = functools.partial(
            count_ops,
            freqs=[
                2,  # 1 from mul recompute, 1 from mul backward
                1,
            ],
            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
        )

        backend = aot_autograd(
            fw_compiler=fw_compiler,
            bw_compiler=bw_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )

        model = MLPModule()
        model = apply_parametrization(model)
        model_compiled = torch.compile(
            copy.deepcopy(model), backend=backend, fullgraph=True
        )
        input = torch.randn(8, 16, requires_grad=True)
        input_compiled = copy.deepcopy(input)

        out = model(input)
        out.sum().backward()
        out_compiled = model_compiled(input_compiled)
        out_compiled.sum().backward()

        self.assertEqual(out, out_compiled)
        self.assertEqual(input.grad, input_compiled.grad)

    @requires_cuda
    @skipIfRocm
    def test_autocast_flash_attention(self):
        def fn(primals_1, primals_2, primals_3):
            return torch.ops.aten._scaled_dot_product_efficient_attention.default(
                primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
            )[0]

        def gn(*args):
            return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)

        with torch.cuda.amp.autocast():
            x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
            y = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
            z = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
            args = (x, y, z)

            torch.manual_seed(0)
            ref = gn(*args)

            opt_gn = torch.compile(gn)
            torch.manual_seed(0)
            res = opt_gn(*args)
            self.assertEqual(ref, res)

    @requires_cuda
    def test_error_msg(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x = torch.sin(x)
                torch._dynamo.graph_break()
                x = torch.cos(x)
                return x

        mod = MockModule().cuda()

        def fn(x):
            return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)

        x = torch.randn(4, 4).cuda()
        opt_fn = torch.compile(fn, fullgraph=True)
        with self.assertRaisesRegex(
            torch._dynamo.exc.Unsupported, "skip function graph_break in file"
        ):
            opt_fn(x)

    @requires_cuda
    def test_list_inputs(self):
        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, ys):
                a = torch.sin(x)
                b = torch.cos(ys[0])
                c = torch.cos(ys[1])
                return (x, [b, c])

        mod = MockModule().cuda()

        def fn(x, ys):
            return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)

        x = torch.randn(4, 4).cuda()
        y = torch.randn(4, 4).cuda()
        z = torch.randn(4, 4).cuda()
        ref = fn(x, [y, z])
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x, [y, z])
        self.assertEqual(ref, res)

    @requires_cuda
    def test_pattern_matcher(self):
        # Check that the sdpa op is recomputed in the backward graph
        # tests percolate_tags

        @checkpoint_wrapper
        def dot_prod_attention(
            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
        ) -> torch.Tensor:
            return (
                torch.matmul(query, key.transpose(-2, -1))
                .mul(1.0 / math.sqrt(key.shape[-1]))
                .softmax(dim=-1)
                .matmul(value)
            )

        def fn(query, key, value):
            # Checks that sin is not recomputed in the backward graph
            return dot_prod_attention(query.sin(), key, value)

        tensor_shape = (4, 2, 16, 32)
        dtype = torch.float16
        args1 = [
            torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
            torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
            torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
        ]

        # Save the AOT graphs
        aot_graphs = []
        from torch._inductor import compile_fx

        def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
            aot_graphs.append(graph)
            return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)

        backend = functools.partial(
            compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
        )

        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
        opt_fn(*args1).sum().backward()
        if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
            op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
        else:
            op = torch.ops.aten._scaled_dot_product_flash_attention.default

        fwd_graph = aot_graphs[0]
        self.assertTrue(
            count_ops(
                fwd_graph,
                [],
                freq=1,
                op=op,
            )
        )

        bwd_graph = aot_graphs[1]
        # Check that sin is not recomputed in the backward graph - checks percolate tags
        self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
        # Check that the sdpa op is recomputed in the backward graph
        self.assertTrue(
            count_ops(
                bwd_graph,
                [],
                freq=1,
                op=op,
            )
        )

    @requires_cuda
    @requires_distributed()
    def test_distributed_utils_checkpoint_wrapper(self):
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            checkpoint_wrapper as dist_checkpoint_wrapper,
        )

        class MockModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = 2

            def forward(self, x):
                x = torch.sin(x)
                x = self.linear(x)
                x = torch.cos(x)
                return x * self.c

        mod = dist_checkpoint_wrapper(MockModule())
        x = torch.randn(4, 4)
        ref = mod(x)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        res = opt_mod(x)
        self.assertEqual(ref, res)

    @requires_cuda
    @requires_distributed()
    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_dynamo_does_not_trace_getattr_as_top_frame(self):
        # inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do.
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointWrapper,
        )

        cnt = CompileCounterWithBackend("eager")

        lin = torch.nn.Linear(1, 1)
        mod = torch.nn.Sequential(lin, lin)
        mod = CheckpointWrapper(mod)
        mod._checkpoint_wrapped_module.a = torch.ones(1, 1)

        def fn(x):
            return mod(x) * mod.a

        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
        x = torch.randn(1, 1)

        self.assertEqual(opt_fn(x), fn(x))


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()