1# Owner(s): ["module: dynamo"]
2import copy
3import functools
4import math
5import unittest  # noqa: F811
6from importlib import import_module
7
8import torch
9import torch._dynamo.config
10import torch._dynamo.test_case
11import torch._functorch.config
12import torch.distributed as dist
13import torch.nn as nn
14import torch.utils.checkpoint
15from functorch.compile import min_cut_rematerialization_partition
16from torch._dynamo.backends.common import aot_autograd
17from torch._dynamo.testing import CompileCounterWithBackend
18from torch._higher_order_ops.wrap import tag_activation_checkpoint
19from torch.testing._internal.common_cuda import (
20    PLATFORM_SUPPORTS_CUDNN_ATTENTION,
21    SM90OrLater,
22)
23from torch.testing._internal.common_utils import IS_WINDOWS, skipIfRocm
24from torch.testing._internal.inductor_utils import HAS_CUDA
25from torch.testing._internal.two_tensor import TwoTensor
26from torch.utils.checkpoint import (
27    checkpoint,
28    CheckpointPolicy,
29    create_selective_checkpoint_contexts,
30)
31
32
33requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
34requires_distributed = functools.partial(
35    unittest.skipIf, not dist.is_available(), "requires distributed"
36)
37
38
39def checkpoint_wrapper(fn):
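    """Decorator: run each call to `fn` under reentrant activation checkpointing."""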
40    def inner(*args):
41        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
42
43    return inner
44
45
46def count_ops(
47    gm, args, freq=None, freq_ge=None, op=None, freqs=None, freqs_ge=None, ops=None
48):
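    """Graph checker used as an aot_autograd fw/bw compiler.

    Asserts that each op in `ops` (or the single `op`) appears exactly `freqs`/`freq`
    times, or at least `freqs_ge`/`freq_ge` times, in `gm`, then returns `gm` unchanged.
    Ops wrapped by the RNG higher-order ops (run_and_save_rng_state / run_with_rng_state)
    are matched against the op they wrap.
    """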
49    def match_rng_op(node, op):
50        if isinstance(node.target, torch._ops.HigherOrderOperator):
51            if node.name == "run_and_save_rng_state":
52                return node.args[0] == op
53            elif node.name == "run_with_rng_state":
54                return node.args[1] == op
55        return False
56
57    # assert ((freq or freq_ge) and op) or ((freqs or freqs_ge) and ops)
58    if op is not None:
59        assert not isinstance(op, list)
60        ops = [op]
61    if freq is not None:
62        freqs = [freq]
63    if freq_ge is not None:
64        freqs_ge = [freq_ge]
65    if freqs:
66        for op, freq in zip(ops, freqs):
67            actual_count = 0
68            for node in gm.graph.nodes:
69                if match_rng_op(node, op) or node.target == op:
70                    actual_count += 1
71            err_msg = f"In graph {gm}, expected {op} to have occurred {freq} times in the graph, but got {actual_count}."
72            assert actual_count == freq, err_msg
73    else:
74        assert freqs_ge is not None
75        for op, freq_ge in zip(ops, freqs_ge):
76            actual_count = 0
77            for node in gm.graph.nodes:
78                if match_rng_op(node, op) or node.target == op:
79                    actual_count += 1
80            assert (
81                actual_count >= freq_ge
82            ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}."
83    return gm
84
85
86class _InvalidContext:
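    """A no-op context manager that is deliberately not a TorchDispatchMode."""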
87    def __init__(self) -> None:
88        pass
89
90    def __enter__(self):
91        return self
92
93    def __exit__(self, exc_type, exc_val, exc_tb):
94        pass
95
96
97def _invalid_context_gen():
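    """Invalid `context_fn`: returns plain context managers instead of two TorchDispatchModes."""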
98    return _InvalidContext(), _InvalidContext()
99
100
101def find_first_node(gm, func):
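    """Return the first node in `gm` whose target is `func`, or None if there is none."""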
102    for node in gm.graph.nodes:
103        if node.target is func:
104            return node
105    return None
106
107
108def op_count(gm):
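    """Count the call_* nodes (call_function / call_method / call_module) in `gm`."""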
109    result = 0
110    for node in gm.graph.nodes:
111        if "call" in node.op:
112            result += 1
113    return result
114
115
116def _get_custom_policy(no_recompute_list=None, must_recompute_list=None):
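    """Build a selective-checkpointing policy: save the outputs of ops in
    `no_recompute_list`, always recompute ops in `must_recompute_list`, and leave
    everything else to the partitioner (PREFER_RECOMPUTE).
    """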
117    def _custom_policy(ctx, func, *args, **kwargs):
118        if no_recompute_list is not None and func in no_recompute_list:
119            return CheckpointPolicy.MUST_SAVE
120        if must_recompute_list is not None and func in must_recompute_list:
121            return CheckpointPolicy.MUST_RECOMPUTE
122        else:
123            return CheckpointPolicy.PREFER_RECOMPUTE
124
125    return _custom_policy
126
127
128class ActivationCheckpointingViaTagsTests(torch._dynamo.test_case.TestCase):
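    """Checks that checkpointed regions are captured, partitioned, and recomputed correctly under torch.compile."""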
129    def _validate(self, fn, backend, *args, skip_check=False, fullgraph=True):
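        """Run `fn` eagerly and under torch.compile with `backend`, re-seeding the RNG
        before each run; unless `skip_check` is set, outputs and input gradients must match.
        """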
130        cloned_args = []
131        for arg in args:
132            cloned_args.append(arg.clone().detach().requires_grad_(arg.requires_grad))
133
134        torch.manual_seed(0)
135        expected = fn(*args)
136        expected.sum().backward()
137
138        torch.manual_seed(0)
139        result = torch.compile(fn, fullgraph=fullgraph, backend=backend)(*cloned_args)
140        result.sum().backward()
141
142        if not skip_check:
143            self.assertEqual(
144                result,
145                expected,
146                msg="Output mismatch between torch.compile and eager versions",
147            )
148            for arg, cloned_arg in zip(args, cloned_args):
149                self.assertEqual(
150                    arg.grad,
151                    cloned_arg.grad,
152                    msg="Gradient mismatch between torch.compile and eager versions",
153                )
154
155    def _compare_orig_and_checkpointed_fns(
156        self, orig_fn, checkpointed_fn, *args, fullgraph=True
157    ):
158        # The original version and the checkpointed version of the same function
159        # should produce the same outputs and the same gradients under torch.compile.
160
161        # Run original version
162        cloned_args_orig_fn = []
163        for arg in args:
164            cloned_args_orig_fn.append(
165                arg.clone().detach().requires_grad_(arg.requires_grad)
166            )
167        torch.manual_seed(0)
168        compiled_orig_fn = torch.compile(
169            orig_fn, fullgraph=fullgraph, backend="inductor"
170        )
171        result_orig_fn = compiled_orig_fn(*cloned_args_orig_fn)
172        result_orig_fn.sum().backward()
173
174        # Run checkpointed version
175        cloned_args_checkpointed_fn = []
176        for arg in args:
177            cloned_args_checkpointed_fn.append(
178                arg.clone().detach().requires_grad_(arg.requires_grad)
179            )
180        torch.manual_seed(0)
181        compiled_checkpointed_fn = torch.compile(
182            checkpointed_fn, fullgraph=fullgraph, backend="inductor"
183        )
184        result_checkpointed_fn = compiled_checkpointed_fn(*cloned_args_checkpointed_fn)
185        result_checkpointed_fn.sum().backward()
186
187        # Check that outputs and gradients are equal
188        self.assertEqual(
189            result_orig_fn,
190            result_checkpointed_fn,
191            msg="Output mismatch between the original version and the checkpointed version of the same function",
192        )
193        for cloned_arg_orig_fn, cloned_arg_checkpointed_fn in zip(
194            cloned_args_orig_fn, cloned_args_checkpointed_fn
195        ):
196            self.assertEqual(
197                cloned_arg_orig_fn.grad,
198                cloned_arg_checkpointed_fn.grad,
199                msg="Gradient mismatch between the original version and the checkpointed version of the same function",
200            )
201
202    @requires_cuda
203    def test_tags_function(self):
204        def gn(x, y):
205            return torch.sigmoid(torch.matmul(x, y))
206
207        def fn(x, y):
208            return torch.utils.checkpoint.checkpoint(
209                gn, torch.sin(x), y, use_reentrant=True
210            )
211
212        x = torch.randn(4, 4, device="cuda", requires_grad=True)
213        y = torch.randn(4, 4, device="cuda", requires_grad=True)
214
215        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
216        bw_compiler = functools.partial(
217            count_ops, freq=3, op=torch.ops.aten.mm.default
218        )  # mm recomputed in the bwd
219        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
220        self._validate(fn, backend, x, y)
221
222    @requires_cuda
223    def test_tags_function_via_global_checkpoint(self):
224        def gn(x, y):
225            return torch.sigmoid(torch.matmul(x, y))
226
227        def fn(x, y):
228            # This goes through VariableBuilder
229            return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
230
231        x = torch.randn(4, 4, device="cuda", requires_grad=True)
232        y = torch.randn(4, 4, device="cuda", requires_grad=True)
233
234        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
235        bw_compiler = functools.partial(
236            count_ops, freq=3, op=torch.ops.aten.mm.default
237        )  # mm recomputed in the bwd
238        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
239        self._validate(fn, backend, x, y)
240
241    @requires_cuda
242    def test_tags_function_with_kwargs(self):
243        def gn(x, y):
244            return torch.sigmoid(torch.matmul(x, y))
245
246        def fn(x, y):
247            return torch.utils.checkpoint.checkpoint(
248                gn, torch.sin(x), y, use_reentrant=True, preserve_rng_state=False
249            )
250
251        x = torch.randn(4, 4, device="cuda", requires_grad=True)
252        y = torch.randn(4, 4, device="cuda", requires_grad=True)
253
254        fw_compiler = functools.partial(count_ops, freq=1, op=torch.ops.aten.mm.default)
255        bw_compiler = functools.partial(
256            count_ops, freq=3, op=torch.ops.aten.mm.default
257        )  # mm recomputed in the bwd
258        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
259        self._validate(fn, backend, x, y)
260
261    @requires_cuda
262    def test_tags_sequential_layers(self):
263        def gn(x):
264            x = x.cos()
265            for _ in range(3):
266                x = torch.mm(x, x)
267            x = x.cos()
268            return x
269
270        def fn(x):
271            x = torch.utils.checkpoint.checkpoint(gn, x)
272            x = torch.utils.checkpoint.checkpoint(gn, x)
273            return x
274
275        x = torch.randn(4, 4, device="cuda", requires_grad=True)
276
277        fw_compiler = functools.partial(count_ops, freq=6, op=torch.ops.aten.mm.default)
278        bw_compiler = functools.partial(
279            count_ops,
280            freqs=[2, 18],
281            ops=[torch.ops.aten.cos.default, torch.ops.aten.mm.default],
282        )  # mm recomputed in the bwd
283        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
284        self._validate(fn, backend, x)
285
286    @requires_cuda
287    def test_tags_multiple_checkpoints(self):
288        def gn(x, y):
289            return torch.sigmoid(torch.matmul(x, y))
290
291        def fn(x, y):
292            x = torch.sin(x)
293            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
294            x = torch.sin(z)
295            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
296            return z
297
298        x = torch.randn(4, 4, device="cuda", requires_grad=True)
299        y = torch.randn(4, 4, device="cuda", requires_grad=True)
300
301        fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
302        bw_compiler = functools.partial(
303            count_ops, freq=6, op=torch.ops.aten.mm.default
304        )  # mm recomputed in the bwd
305        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
306        self._validate(fn, backend, x, y)
307
308    @requires_cuda
309    def test_tags_module(self):
310        class MockModule(torch.nn.Module):
311            def __init__(self) -> None:
312                super().__init__()
313                self.linear = torch.nn.Linear(10, 10)
314
315            def forward(self, x):
316                return torch.sigmoid(self.linear(x))
317
318        mod = MockModule().cuda()
319
320        def fn(x):
321            return torch.utils.checkpoint.checkpoint(
322                mod, torch.sin(x), use_reentrant=True
323            )
324
325        x = torch.randn(10, 10, device="cuda", requires_grad=True)
326
327        fw_compiler = functools.partial(
328            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
329        )
330        bw_compiler = functools.partial(
331            count_ops, freq=1, op=torch.ops.aten.sigmoid.default
332        )
333        backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
334        self._validate(fn, backend, x)
335
336    @requires_cuda
337    def test_tags_decomps(self):
338        # Ensures that tags are passed on through decompositions as well
339        class MockModule(torch.nn.Module):
340            def __init__(self) -> None:
341                super().__init__()
342                self.linear = torch.nn.Linear(10, 10)
343
344            def forward(self, x):
345                return torch.nn.functional.gelu(self.linear(x))
346
347        mod = MockModule().cuda()
348
349        def fn(x):
350            return torch.utils.checkpoint.checkpoint(
351                mod, torch.sin(x), use_reentrant=True
352            )
353
354        x = torch.randn(10, 10, device="cuda", requires_grad=True)
355
356        fw_compiler = functools.partial(
357            count_ops, freq=1, op=torch.ops.aten.erf.default
358        )
359        bw_compiler = functools.partial(
360            count_ops, freq=1, op=torch.ops.aten.erf.default
361        )
362        backend = aot_autograd(
363            fw_compiler=fw_compiler,
364            bw_compiler=bw_compiler,
365            decompositions=lambda: import_module(
366                "torch._inductor.compile_fx"
367            ).select_decomp_table(),
368        )
369        self._validate(fn, backend, x)
370
371    @requires_cuda
372    @torch._inductor.config.patch(fallback_random=True)
373    def test_tags_recomputed_rand(self):
374        def gn(x, y):
375            return torch.sigmoid(torch.rand_like(x) * y) * x
376
377        def fn(x, y):
378            x = torch.sin(x)
379            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
380            x = torch.sin(x)
381            z = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
382            return z
383
384        x = torch.randn(4, 4, device="cuda", requires_grad=True)
385        y = torch.randn(4, 4, device="cuda", requires_grad=True)
386
387        # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
388        # bw_compiler = functools.partial(
389        #     count_ops, freq=6, op=torch.ops.aten.mm.default
390        # )  # mm recomputed in the bwd
391        # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
392        backend = "inductor"
393        self._validate(fn, backend, x, y)
394
395    @requires_cuda
396    @torch._inductor.config.patch(fallback_random=True)
397    def test_tags_rand(self):
398        def gn(x, y):
399            x = torch.mm(x, y)
400            x = torch.mm(x, y)
401            return x
402
403        def fn(x, y):
404            x = torch.sin(x)
405            x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
406            x = torch.sin(x)
407            # x = torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
408            return x
409
410        x = torch.randn(4, 4, device="cuda", requires_grad=True)
411        y = torch.randn(4, 4, device="cuda", requires_grad=True)
412
413        # fw_compiler = functools.partial(count_ops, freq=2, op=torch.ops.aten.mm.default)
414        # bw_compiler = functools.partial(
415        #     count_ops, freq=6, op=torch.ops.aten.mm.default
416        # )  # mm recomputed in the bwd
417        # backend = aot_autograd(fw_compiler=fw_compiler, bw_compiler=bw_compiler)
418        # backend = "aot_eager"
419        backend = "inductor"
420        self._validate(fn, backend, x, y)
421
422    @requires_cuda
423    @torch._inductor.config.patch(fallback_random=True)
424    def test_tags_dropout(self):
425        # TODO: Figure out a way to test the number of inductor_random calls
426        class MockModule(torch.nn.Module):
427            def __init__(self) -> None:
428                super().__init__()
429                self.linear = torch.nn.Linear(10, 10)
430                self.dropout = torch.nn.Dropout(0.2)
431
432            def forward(self, x):
433                return self.dropout(self.linear(x))
434
435        mod = MockModule().cuda()
436
437        def fn(x):
438            return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
439
440        x = torch.randn(10, 10, device="cuda", requires_grad=True)
441        backend = "inductor"
442        # rand decomps do not have have numerical results as eager
443        self._validate(fn, backend, x, skip_check=True)
444
445    @requires_cuda
446    def test_fallback(self):
447        def gn(x, y):
448            torch._dynamo.graph_break()
449            a = torch.sigmoid(torch.matmul(x, y))
450            torch._dynamo.graph_break()
451            return torch.cos(a)
452
453        def fn(x, y):
454            return torch.cos(checkpoint(gn, torch.sin(x), y, use_reentrant=False))
455
456        x = torch.randn(4, 4, requires_grad=True)
457        y = torch.randn(4, 4, requires_grad=True)
458        args = (x, y)
459
460        backend = "aot_eager"
461        cnt = CompileCounterWithBackend(backend)
462
463        expected = fn(*args)
464        result = torch.compile(fn, backend=cnt)(*args)
465
466        self.assertEqual(result, expected)
467
468        # One graph for torch.sin on the input, and the other for torch.cos.
469        self.assertEqual(cnt.frame_count, 2)
470        self.assertEqual(cnt.op_count, 2)
471        self.assertEqual(len(cnt.graphs), 2)
472
473    @requires_cuda
474    def test_kwargs(self):
475        def gn(x, y, z=None):
476            a = torch.matmul(x, y)
477            if z is not None:
478                return torch.matmul(a, z)
479            return a
480
481        def fn(x, y, z):
482            return torch.cos(checkpoint(gn, x, y, use_reentrant=False, z=z))
483
484        x = torch.randn(4, 4, requires_grad=True)
485        y = torch.randn(4, 4, requires_grad=True)
486        z = torch.randn(4, 4, requires_grad=True)
487        args = (x, y, z)
488
489        backend = "aot_eager"
490        cnt = CompileCounterWithBackend(backend)
491
492        expected = fn(*args)
493        result = torch.compile(fn, backend=cnt)(*args)
494
495        self.assertEqual(result, expected)
496
497        self.assertEqual(cnt.frame_count, 1)
498        self.assertEqual(len(cnt.graphs), 1)
499
500        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
501        # one arg for the checkpointed body function, and 3 for x, y, z
502        self.assertEqual(len(wrap_node.args), 4)
503
504        body_function = getattr(cnt.graphs[0], wrap_node.args[0].name)
505        self.assertEqual(op_count(body_function), 2)
506
507    @requires_cuda
508    def test_symints_location(self):
509        def gn(x, y):
510            return torch.matmul(x, torch.nn.functional.dropout(y, 0.5))
511
512        def fn(x, y):
513            return torch.utils.checkpoint.checkpoint(gn, x, y, use_reentrant=True)
514
515        backend = "aot_eager"
516        cnt = CompileCounterWithBackend(backend)
517        opt_fn = torch.compile(fn, backend=cnt)
518
519        x = torch.randn(4, 4, requires_grad=True)
520        y = torch.randn(4, 4, requires_grad=True)
521        args = (x, y)
522        expected = fn(*args)
523        result = opt_fn(*args)
524
525        x = torch.randn(5, 5, requires_grad=True)
526        y = torch.randn(5, 5, requires_grad=True)
527        args = (x, y)
528        expected = fn(*args)
529        result = opt_fn(*args)
530
531        self.assertEqual(result.shape, expected.shape)
532        self.assertEqual(cnt.frame_count, 2)
533        self.assertEqual(len(cnt.graphs), 2)
534        wrap_node = find_first_node(cnt.graphs[0], tag_activation_checkpoint)
535        self.assertEqual(len(wrap_node.args), 3)
536
537    @requires_cuda
538    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
539    def test_compile_selective_checkpoint_must_recompute(self):
540        def context_fn_must_recompute_mm():
541            must_recompute_list = [
542                torch.ops.aten.mm.default,
543            ]
544            return create_selective_checkpoint_contexts(
545                _get_custom_policy(
546                    must_recompute_list=must_recompute_list,
547                ),
548            )
549
550        def context_fn_no_recompute_mm():
551            no_recompute_list = [
552                torch.ops.aten.mm.default,
553            ]
554            return create_selective_checkpoint_contexts(
555                _get_custom_policy(
556                    no_recompute_list=no_recompute_list,
557                ),
558            )
559
560        def _test(context_fn, bw_compiler):
561            def gn(x):
562                return torch.sigmoid(torch.matmul(x, x))
563
564            def fn(x):
565                return torch.utils.checkpoint.checkpoint(
566                    gn,
567                    x,
568                    use_reentrant=False,
569                    context_fn=context_fn,
570                )
571
572            x = torch.randn(4, 4, requires_grad=True)
573
574            fw_compiler = functools.partial(
575                count_ops,
576                freq=1,
577                op=torch.ops.aten.mm.default,
578            )
579
580            backend = aot_autograd(
581                fw_compiler=fw_compiler,
582                bw_compiler=bw_compiler,
583                partition_fn=min_cut_rematerialization_partition,
584            )
585            self._validate(fn, backend, x)
586
587        _test(
588            context_fn=context_fn_must_recompute_mm,
589            bw_compiler=functools.partial(
590                count_ops,
591                freq=3,  # 1 matmul recompute and 2 bwd mm ops per fwd matmul, so 1 + 2 * 1 = 3
592                op=torch.ops.aten.mm.default,
593            ),
594        )
595        _test(
596            context_fn=context_fn_no_recompute_mm,
597            bw_compiler=functools.partial(
598                count_ops,
599                freq=2,  # 2 bwd mm ops per fwd matmul
600                op=torch.ops.aten.mm.default,
601            ),
602        )
603
604    @requires_cuda
605    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
606    def test_compile_selective_checkpoint_must_not_recompute_gemm(self):
607        def selective_checkpointing_context_fn():
608            no_recompute_list = [
609                torch.ops.aten.mm.default,
610            ]
611            return create_selective_checkpoint_contexts(
612                _get_custom_policy(no_recompute_list=no_recompute_list)
613            )
614
615        def gn(x, y):
616            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
617
618        def fn(x, y):
619            return torch.utils.checkpoint.checkpoint(
620                gn,
621                x,
622                y,
623                use_reentrant=False,
624                context_fn=selective_checkpointing_context_fn,
625            )
626
627        x = torch.randn(4, 4, requires_grad=True, device="cuda")
628        y = torch.randn(4, 4, requires_grad=True, device="cuda")
629
630        fw_compiler = functools.partial(
631            count_ops,
632            freq=2,
633            op=torch.ops.aten.mm.default,
634        )
635        bw_compiler = functools.partial(
636            count_ops,
637            # We would've expected 6 here
638            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
639            # if we didn't enable selective checkpointing.
640            freq=4,
641            op=torch.ops.aten.mm.default,
642        )
643        backend = aot_autograd(
644            fw_compiler=fw_compiler,
645            bw_compiler=bw_compiler,
646            partition_fn=min_cut_rematerialization_partition,
647        )
648        self._validate(fn, backend, x, y)
649        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
650
651    @requires_cuda
652    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
653    def test_compile_selective_checkpoint_tensor_subclass(self):
654        def selective_checkpointing_context_fn():
655            no_recompute_list = [
656                torch.ops.aten.mm.default,
657            ]
658            return create_selective_checkpoint_contexts(
659                _get_custom_policy(no_recompute_list=no_recompute_list)
660            )
661
662        def gn(x, y):
663            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
664
665        def fn(x, y):
666            return torch.utils.checkpoint.checkpoint(
667                gn,
668                x,
669                y,
670                use_reentrant=False,
671                context_fn=selective_checkpointing_context_fn,
672            )
673
674        rand_tensor = torch.randn(4, 4, requires_grad=True, device="cuda")
675
676        # tensor subclasses as inputs
677        x = TwoTensor(rand_tensor, rand_tensor.clone())
678        y = TwoTensor(rand_tensor.clone(), rand_tensor.clone())
679
680        fw_compiler = functools.partial(
681            count_ops,
682            freq=4,
683            op=torch.ops.aten.mm.default,
684        )
685        bw_compiler = functools.partial(
686            count_ops,
687            # We would've expected 12 here
688            # (4 matmul recompute and 4 mm ops per fwd matmul, so 4 + 2 * 4 = 12)
689            # if we didn't enable selective checkpointing.
690            freq=8,
691            op=torch.ops.aten.mm.default,
692        )
693        backend = aot_autograd(
694            fw_compiler=fw_compiler,
695            bw_compiler=bw_compiler,
696            partition_fn=min_cut_rematerialization_partition,
697        )
698        self._validate(fn, backend, x, y)
699        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
700
701    @requires_cuda
702    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
703    def test_compile_selective_checkpoint_custom_rule(self):
704        def _get_custom_policy(meta):
705            no_recompute_list = [
706                torch.ops.aten.mm.default,
707            ]
708
709            def _custom_policy(mode, func, *args, **kwargs):
710                mm_count_key = f"{mode}_mm_count"
711                if mm_count_key not in meta:
712                    meta[mm_count_key] = 0
713                if func == torch.ops.aten.mm.default:
714                    meta[mm_count_key] += 1
715                # Save the output of every mm (the only op in no_recompute_list) except the second one
716                # (i.e. hint the partitioner to recompute the second mm in the backward pass)
717                return func in no_recompute_list and not (
718                    func == torch.ops.aten.mm.default and meta[mm_count_key] == 2
719                )
720
721            return _custom_policy
722
723        def selective_checkpointing_context_fn():
724            meta = {}
725            return create_selective_checkpoint_contexts(_get_custom_policy(meta))
726
727        def gn(x, y):
728            return torch.sigmoid(
729                torch.sigmoid(torch.matmul(torch.matmul(x, y) * y, y) * y)
730            )
731
732        def fn(x, y):
733            return torch.utils.checkpoint.checkpoint(
734                gn,
735                x,
736                y,
737                use_reentrant=False,
738                context_fn=selective_checkpointing_context_fn,
739            )
740
741        x = torch.randn(4, 4, requires_grad=True, device="cuda")
742        y = torch.randn(4, 4, requires_grad=True, device="cuda")
743
744        fw_compiler = functools.partial(
745            count_ops,
746            freq=2,
747            op=torch.ops.aten.mm.default,
748        )
749        bw_compiler = functools.partial(
750            count_ops,
751            # Q: Where does the number 4 come from?
752            # A: There are 2 matmuls in the forward pass, and each matmul contributes 2 `mm` ops to the backward pass,
753            # so we have at least 4 `mm` ops in the backward pass. It's "at least" because whether the second matmul
754            # in the forward pass is recomputed in the backward pass is up to the partitioner to decide.
755            freq_ge=4,
756            op=torch.ops.aten.mm.default,
757        )
758        backend = aot_autograd(
759            fw_compiler=fw_compiler,
760            bw_compiler=bw_compiler,
761            partition_fn=min_cut_rematerialization_partition,
762        )
763        self._validate(fn, backend, x, y)
764        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
765
766    @requires_cuda
767    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
768    def test_compile_selective_checkpoint_partial_ctx_fn(self):
769        def selective_checkpointing_context_fn(no_recompute_list):
770            return create_selective_checkpoint_contexts(
771                _get_custom_policy(no_recompute_list=no_recompute_list)
772            )
773
774        def gn(x, y):
775            return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
776
777        def fn(x, y):
778            return torch.utils.checkpoint.checkpoint(
779                gn,
780                x,
781                y,
782                use_reentrant=False,
783                context_fn=functools.partial(
784                    selective_checkpointing_context_fn, [torch.ops.aten.mm.default]
785                ),
786            )
787
788        x = torch.randn(4, 4, requires_grad=True, device="cuda")
789        y = torch.randn(4, 4, requires_grad=True, device="cuda")
790
791        fw_compiler = functools.partial(
792            count_ops,
793            freq=2,
794            op=torch.ops.aten.mm.default,
795        )
796        bw_compiler = functools.partial(
797            count_ops,
798            # We would've expected 6 here
799            # (2 matmul recompute and 2 mm ops per fwd matmul, so 2 + 2 * 2 = 6)
800            # if we didn't enable selective checkpointing.
801            freq=4,
802            op=torch.ops.aten.mm.default,
803        )
804        backend = aot_autograd(
805            fw_compiler=fw_compiler,
806            bw_compiler=bw_compiler,
807            partition_fn=min_cut_rematerialization_partition,
808        )
809        self._validate(fn, backend, x, y)
810        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
811
812    @requires_cuda
813    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
814    def test_compile_selective_checkpoint_outplace_op(self):
815        def selective_checkpointing_context_fn():
816            no_recompute_list = [
817                torch.ops.aten.mm.default,
818                torch.ops.aten.sigmoid.default,
819            ]
820            return create_selective_checkpoint_contexts(
821                _get_custom_policy(no_recompute_list=no_recompute_list),
822            )
823
824        def gn(x, y):
825            return torch.sigmoid(torch.selu(torch.matmul(torch.matmul(x, y), y))).relu()
826
827        def fn(x, y):
828            return torch.utils.checkpoint.checkpoint(
829                gn,
830                x,
831                y,
832                use_reentrant=False,
833                context_fn=selective_checkpointing_context_fn,
834            )
835
836        x = torch.randn(4, 4, requires_grad=True, device="cuda")
837        y = torch.randn(4, 4, requires_grad=True, device="cuda")
838
839        fw_compiler = functools.partial(
840            count_ops,
841            freqs=[2, 1],
842            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
843        )
844        bw_compiler = functools.partial(
845            count_ops,
846            freqs=[4, 0],
847            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
848        )
849        backend = aot_autograd(
850            fw_compiler=fw_compiler,
851            bw_compiler=bw_compiler,
852            partition_fn=min_cut_rematerialization_partition,
853        )
854        self._validate(fn, backend, x, y)
855        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
856
857    @requires_cuda
858    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
859    @unittest.skip(
860        "In-place op support in selective checkpointing + torch.compile "
861        "requires TorchDispatchMode + torch.compile work to complete"
862    )
863    def test_compile_selective_checkpoint_inplace_op(self):
864        def selective_checkpointing_context_fn():
865            no_recompute_list = [
866                torch.ops.aten.mm.default,
867                torch.ops.aten.sigmoid.default,
868            ]
869            return create_selective_checkpoint_contexts(
870                _get_custom_policy(no_recompute_list=no_recompute_list)
871            )
872
873        def gn(x, y):
874            return torch.sigmoid(
875                torch.selu_(torch.matmul(torch.matmul(x, y), y))
876            ).relu_()
877
878        def fn(x, y):
879            return torch.utils.checkpoint.checkpoint(
880                gn,
881                x,
882                y,
883                use_reentrant=False,
884                context_fn=selective_checkpointing_context_fn,
885            )
886
887        x = torch.randn(4, 4, requires_grad=True, device="cuda")
888        y = torch.randn(4, 4, requires_grad=True, device="cuda")
889
890        fw_compiler = functools.partial(
891            count_ops,
892            freqs=[2, 1],
893            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
894        )
895        bw_compiler = functools.partial(
896            count_ops,
897            freqs=[4, 0],
898            ops=[torch.ops.aten.mm.default, torch.ops.aten.sigmoid.default],
899        )
900        backend = aot_autograd(
901            fw_compiler=fw_compiler,
902            bw_compiler=bw_compiler,
903            partition_fn=min_cut_rematerialization_partition,
904        )
905        self._validate(fn, backend, x, y)
906        self._compare_orig_and_checkpointed_fns(gn, fn, x, y)
907
908    @requires_cuda
909    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
910    def test_compile_selective_checkpoint_random_op(self):
911        for preserve_rng_state in [True, False]:
912
913            def selective_checkpointing_context_fn():
914                no_recompute_list = [
915                    torch.ops.aten.sigmoid.default,
916                ]
917                return create_selective_checkpoint_contexts(
918                    _get_custom_policy(no_recompute_list=no_recompute_list)
919                )
920
921            def gn(x):
922                return torch.sigmoid(torch.dropout(torch.sigmoid(x), p=0.5, train=True))
923
924            def fn(x):
925                return torch.utils.checkpoint.checkpoint(
926                    gn,
927                    x,
928                    use_reentrant=False,
929                    # Regardless of whether `preserve_rng_state` is True or False,
930                    # we will always preserve RNG state when using `torch.compile`.
931                    preserve_rng_state=preserve_rng_state,
932                    context_fn=selective_checkpointing_context_fn,
933                )
934
935            x = torch.randn(4, 4, requires_grad=True, device="cuda")
936
937            fw_compiler = functools.partial(
938                count_ops,
939                freqs=[2, 1],
940                ops=[
941                    torch.ops.aten.sigmoid.default,
942                    torch.ops.aten.native_dropout.default,
943                ],
944            )
945            bw_compiler = functools.partial(
946                count_ops,
947                # NOTE: This unit test expects `dropout` to be recomputed (notice the count for `native_dropout` is 1).
948                freqs=[0, 1],
949                ops=[
950                    torch.ops.aten.sigmoid.default,
951                    torch.ops.aten.native_dropout.default,
952                ],
953            )
954            backend = aot_autograd(
955                fw_compiler=fw_compiler,
956                bw_compiler=bw_compiler,
957                partition_fn=min_cut_rematerialization_partition,
958            )
959
960            # NOTE: when `preserve_rng_state` is False, gradients will mismatch between torch.compile and eager,
961            # because the eager version doesn't preserve RNG state while torch.compile still does.
962            # Hence, when `preserve_rng_state` is False, we skip the output and gradient comparison
963            # between torch.compile and eager.
964            self._validate(fn, backend, x, skip_check=not preserve_rng_state)
965            self._compare_orig_and_checkpointed_fns(gn, fn, x)
966
967    @unittest.skipIf(IS_WINDOWS, "torch.compile doesn't work with windows")
968    def test_compile_selective_checkpoint_invalid_context(self):
969        def gn(x, y):
970            return torch.sigmoid(torch.matmul(x, y)) * y
971
972        def fn(x, y):
973            return torch.utils.checkpoint.checkpoint(
974                gn,
975                x,
976                y,
977                use_reentrant=False,
978                context_fn=_invalid_context_gen,
979            )
980
981        x = torch.randn(4, 4, requires_grad=True)
982        y = torch.randn(4, 4, requires_grad=True)
983
984        fw_compiler = functools.partial(
985            count_ops,
986            freq=1,
987            op=torch.ops.aten.mm.default,
988        )
989        bw_compiler = functools.partial(
990            count_ops,
991            freq_ge=2,
992            op=torch.ops.aten.mm.default,
993        )
994        backend = aot_autograd(
995            fw_compiler=fw_compiler,
996            bw_compiler=bw_compiler,
997            partition_fn=min_cut_rematerialization_partition,
998        )
999        with self.assertRaisesRegex(
1000            Exception, "must generate a tuple of two `TorchDispatchMode`s"
1001        ):
1002            self._validate(fn, backend, x, y)
1003
1004    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
1005    def test_compile_selective_checkpoint_parametrization(self):
1006        def sac_policy():
1007            def _recomp_policy():
1008                def _custom_policy(ctx, func, *args, **kwargs):
1009                    to_recompute = func in {
1010                        torch.ops.aten.mul.Tensor,
1011                        torch.ops.aten.sigmoid.default,
1012                    }
1013                    return (
1014                        CheckpointPolicy.MUST_RECOMPUTE
1015                        if to_recompute
1016                        else CheckpointPolicy.MUST_SAVE
1017                    )
1018
1019                return _custom_policy
1020
1021            return create_selective_checkpoint_contexts(_recomp_policy())
1022
1023        class Parametrization(torch.nn.Module):
1024            def __init__(self) -> None:
1025                super().__init__()
1026
1027            def parametrization(self, x):
1028                return torch.sigmoid(torch.mul(x, x))
1029
1030            def forward(self, x):
1031                return checkpoint(
1032                    self.parametrization, x, use_reentrant=False, context_fn=sac_policy
1033                )
1034
1035        def apply_parametrization(model):
1036            modules = list(model.modules())
1037
1038            for mod in modules:
1039                params_dict = dict(mod.named_parameters(recurse=False))
1040                for p_name, p in params_dict.items():
1041                    mod.register_parameter(p_name, nn.Parameter(p))
1042                    nn.utils.parametrize.register_parametrization(
1043                        mod, p_name, Parametrization(), unsafe=True
1044                    )
1045
1046            return model
1047
1048        class MLPModule(nn.Module):
1049            def __init__(self) -> None:
1050                super().__init__()
1051                torch.manual_seed(5)
1052                self.net1 = nn.Linear(16, 16, bias=False)
1053
1054            def forward(self, x):
1055                return self.net1(x)
1056
1057            def reset_parameters(self):
1058                self.net1.reset_parameters()
1059
1060        fw_compiler = functools.partial(
1061            count_ops,
1062            freqs=[1, 1],
1063            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
1064        )
1065        bw_compiler = functools.partial(
1066            count_ops,
1067            freqs=[
1068                2,  # 1 from mul recompute, 1 from mul backward
1069                1,
1070            ],
1071            ops=[torch.ops.aten.mul.Tensor, torch.ops.aten.sigmoid.default],
1072        )
1073
1074        backend = aot_autograd(
1075            fw_compiler=fw_compiler,
1076            bw_compiler=bw_compiler,
1077            partition_fn=min_cut_rematerialization_partition,
1078        )
1079
1080        model = MLPModule()
1081        model = apply_parametrization(model)
1082        model_compiled = torch.compile(
1083            copy.deepcopy(model), backend=backend, fullgraph=True
1084        )
1085        input = torch.randn(8, 16, requires_grad=True)
1086        input_compiled = copy.deepcopy(input)
1087
1088        out = model(input)
1089        out.sum().backward()
1090        out_compiled = model_compiled(input_compiled)
1091        out_compiled.sum().backward()
1092
1093        self.assertEqual(out, out_compiled)
1094        self.assertEqual(input.grad, input_compiled.grad)
1095
1096    @requires_cuda
1097    @skipIfRocm
1098    def test_autocast_flash_attention(self):
1099        def fn(primals_1, primals_2, primals_3):
1100            return torch.ops.aten._scaled_dot_product_efficient_attention.default(
1101                primals_1, primals_2, primals_3, None, True, scale=0.17677669529663687
1102            )[0]
1103
1104        def gn(*args):
1105            return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=True)
1106
1107        with torch.cuda.amp.autocast():
1108            x = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
1109            y = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
1110            z = torch.randn(4, 2, 16, 32, device="cuda", requires_grad=True)
1111            args = (x, y, z)
1112
1113            torch.manual_seed(0)
1114            ref = gn(*args)
1115
1116            opt_gn = torch.compile(gn)
1117            torch.manual_seed(0)
1118            res = opt_gn(*args)
1119            self.assertEqual(ref, res)
1120
1121    @requires_cuda
1122    def test_error_msg(self):
1123        class MockModule(torch.nn.Module):
1124            def __init__(self) -> None:
1125                super().__init__()
1126
1127            def forward(self, x):
1128                x = torch.sin(x)
1129                torch._dynamo.graph_break()
1130                x = torch.cos(x)
1131                return x
1132
1133        mod = MockModule().cuda()
1134
1135        def fn(x):
1136            return torch.utils.checkpoint.checkpoint(mod, x, use_reentrant=True)
1137
1138        x = torch.randn(4, 4).cuda()
1139        opt_fn = torch.compile(fn, fullgraph=True)
1140        with self.assertRaisesRegex(
1141            torch._dynamo.exc.Unsupported, "skip function graph_break in file"
1142        ):
1143            opt_fn(x)
1144
1145    @requires_cuda
1146    def test_list_inputs(self):
1147        class MockModule(torch.nn.Module):
1148            def __init__(self) -> None:
1149                super().__init__()
1150
1151            def forward(self, x, ys):
1152                a = torch.sin(x)
1153                b = torch.cos(ys[0])
1154                c = torch.cos(ys[1])
1155                return (x, [b, c])
1156
1157        mod = MockModule().cuda()
1158
1159        def fn(x, ys):
1160            return torch.utils.checkpoint.checkpoint(mod, x, ys, use_reentrant=True)
1161
1162        x = torch.randn(4, 4).cuda()
1163        y = torch.randn(4, 4).cuda()
1164        z = torch.randn(4, 4).cuda()
1165        ref = fn(x, [y, z])
1166        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1167        res = opt_fn(x, [y, z])
1168        self.assertEqual(ref, res)
1169
1170    @requires_cuda
1171    def test_pattern_matcher(self):
1172        # Check that the sdpa op is recomputed in the backward graph
1173        # tests percolate_tags
1174
1175        @checkpoint_wrapper
1176        def dot_prod_attention(
1177            query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
1178        ) -> torch.Tensor:
1179            return (
1180                torch.matmul(query, key.transpose(-2, -1))
1181                .mul(1.0 / math.sqrt(key.shape[-1]))
1182                .softmax(dim=-1)
1183                .matmul(value)
1184            )
1185
1186        def fn(query, key, value):
1187            # Checks that sin is not recomputed in the backward graph
1188            return dot_prod_attention(query.sin(), key, value)
1189
1190        tensor_shape = (4, 2, 16, 32)
1191        dtype = torch.float16
1192        args1 = [
1193            torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
1194            torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
1195            torch.randn(tensor_shape, device="cuda", dtype=dtype, requires_grad=True),
1196        ]
1197
1198        # Save the AOT graphs
1199        aot_graphs = []
1200        from torch._inductor import compile_fx
1201
1202        def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
1203            aot_graphs.append(graph)
1204            return compile_fx.compile_fx_inner(graph, example_inputs, *args, **kwargs)
1205
1206        backend = functools.partial(
1207            compile_fx.compile_fx, inner_compile=debug_compile_fx_inner
1208        )
1209
1210        opt_fn = torch.compile(fn, backend=backend, fullgraph=True)
1211        opt_fn(*args1).sum().backward()
1212        if PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater:
1213            op = torch.ops.aten._scaled_dot_product_cudnn_attention.default
1214        else:
1215            op = torch.ops.aten._scaled_dot_product_flash_attention.default
1216
1217        fwd_graph = aot_graphs[0]
1218        self.assertTrue(
1219            count_ops(
1220                fwd_graph,
1221                [],
1222                freq=1,
1223                op=op,
1224            )
1225        )
1226
1227        bwd_graph = aot_graphs[1]
1228        # Check that sin is not recomputed in the backward graph - checks percolate tags
1229        self.assertTrue(count_ops(bwd_graph, [], freq=0, op=torch.ops.aten.sin.default))
1230        # Check that the sdpa op is recomputed in the backward graph
1231        self.assertTrue(
1232            count_ops(
1233                bwd_graph,
1234                [],
1235                freq=1,
1236                op=op,
1237            )
1238        )
1239
1240    @requires_cuda
1241    @requires_distributed()
1242    def test_distributed_utils_checkpoint_wrapper(self):
1243        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
1244            checkpoint_wrapper as dist_checkpoint_wrapper,
1245        )
1246
1247        class MockModule(torch.nn.Module):
1248            def __init__(self) -> None:
1249                super().__init__()
1250                self.linear = torch.nn.Linear(4, 4)
1251                self.c = 2
1252
1253            def forward(self, x):
1254                x = torch.sin(x)
1255                x = self.linear(x)
1256                x = torch.cos(x)
1257                return x * self.c
1258
1259        mod = dist_checkpoint_wrapper(MockModule())
1260        x = torch.randn(4, 4)
1261        ref = mod(x)
1262        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
1263        res = opt_mod(x)
1264        self.assertEqual(ref, res)
1265
1266    @requires_cuda
1267    @requires_distributed()
1268    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
1269    def test_dynamo_does_not_trace_getattr_as_top_frame(self):
1270        # inline_inbuilt_nn_modules is a proxy to emulate what FSDP tests do.
1271        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
1272            CheckpointWrapper,
1273        )
1274
1275        cnt = CompileCounterWithBackend("eager")
1276
1277        lin = torch.nn.Linear(1, 1)
1278        mod = torch.nn.Sequential(lin, lin)
1279        mod = CheckpointWrapper(mod)
1280        mod._checkpoint_wrapped_module.a = torch.ones(1, 1)
1281
1282        def fn(x):
1283            return mod(x) * mod.a
1284
1285        opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
1286        x = torch.randn(1, 1)
1287
1288        self.assertEqual(opt_fn(x), fn(x))
1289
1290
1291if __name__ == "__main__":
1292    from torch._dynamo.test_case import run_tests
1293
1294    run_tests()
1295