# Owner(s): ["module: dynamo"]

import os
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._functorch._aot_autograd
from torch._dynamo import config as dynamo_config
from torch._dynamo.utils import counters
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import (
    AOTAutogradCache,
    autograd_cache_key,
    BypassAOTAutogradCache,
)
from torch._functorch._aot_autograd.schemas import AOTConfig
from torch._inductor import config as inductor_config
from torch._inductor.test_case import TestCase as InductorTestCase
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import largeTensorTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    skipIfWindows,
)
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


@instantiate_parametrized_tests
class AOTAutogradCacheTests(InductorTestCase):
    def setUp(self):
        """
        Reset all counters and caches before each unit test
        """
        super().setUp()
        counters.clear()
        self._clear_all_caches()

    def _clear_all_caches(self):
        """
        Clear every cache, including AOTAutogradCache and FxGraphCache
        """
        torch._inductor.codecache.FxGraphCache.clear()
        AOTAutogradCache.clear()
        self._clear_dynamo_and_codecache()

    def _clear_dynamo_and_codecache(self):
        """
        Clear unrelated caches, like dynamo and PyCodeCache
        """
        torch._dynamo.reset()
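        # Also delete the generated source files on disk so that later loads cannot
        # pick up stale modules instead of exercising the cache machinery.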
        for m in torch._inductor.codecache.PyCodeCache.cache.values():
            os.remove(m.__file__)
        torch._inductor.codecache.PyCodeCache.cache_clear()

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_basic(self):
        """
        Verify the interactions between FXGraphCache and AOTAutogradCache.
        """

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a, b))

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @skipIfWindows(
        msg="Known issue: Windows can't delete loaded modules, so we can't clear the module cache."
    )
    def test_clear_fx_graph_cache(self):
        """
        Verify that clearing the FXGraphCache forces AOTAutogradCache to miss and save again.
        """

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear FX graph cache: second call should also be a miss
        self._clear_dynamo_and_codecache()
        torch._inductor.codecache.FxGraphCache.clear()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        # We save again into the cache
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 2)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_fx_graph_cache_off(self):
        """
        Should not use cache if FXGraphCache is not enabled
        """

        def fn(x, y):
            return (x * 2, y @ y)

        a = torch.rand(25)
        b = torch.rand(5, 5)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # With FXGraphCache still disabled, a second call should bypass again
        self._clear_dynamo_and_codecache()

        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_bypass"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @dynamo_config.patch("compiled_autograd", True)
    def test_compiled_autograd_bypass(self):
        def fn(a, b):
            out = a.cos() + b
            loss = out.sum()
            ga, gb = torch.autograd.grad(loss, inputs=[a, b])

        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)
        compiled_fn = torch.compile(fn, backend="inductor")
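        # fn returns None, so this equality only checks that both calls run; the
        # interesting signal is the pair of counters asserted below.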
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(
            counters["aot_autograd"]["autograd_cache_miss"], 1
        )  # from compiled forward
        self.assertEqual(
            counters["aot_autograd"]["autograd_cache_bypass"], 1
        )  # from compiled autograd

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    @dynamo_config.patch("compiled_autograd", True)
    def test_inference_graph_cache_hit_with_compiled_autograd_enabled(self):
        def fn(a, b):
            out = a.cos() + b
            return out.sum()

        a = torch.randn(25)
        b = torch.randn(25)
        compiled_fn = torch.compile(fn, backend="inductor")
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear dynamo and run again. Should be a cache hit.
        counters.clear()
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a, b))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch({"fx_graph_cache": True})
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_lazy_backward(self):
        """
        Lazily compile the backward, and lazily save to cache
        """

        def fn(a, b):
            return a.cos() + b

        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)
        compiled_fn = torch.compile(fn, backend="inductor")
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # Clear dynamo and run again. Should be a cache miss still, because backward hasn't run
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 0)

        # Now let's run the backward
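        # Compiling the backward lazily on first use is what finally writes the
        # entry to the cache (autograd_cache_saved goes from 0 to 1 below).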
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Clear dynamo and rerun everything, now there should be a cache hit
        self._clear_dynamo_and_codecache()
        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 2)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)

    @inductor_config.patch("fx_graph_remote_cache", False)
    @inductor_config.patch("fx_graph_cache", True)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_function(self):
        """
        Tests autograd cache hits
        """

        def fn(a, b):
            return a.sin() + b

        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)

        compiled_fn = torch.compile(fn, backend="inductor")

        # A first call should miss in the cache.
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

        # Reset all tensors
        a = torch.randn(25, requires_grad=True)
        b = torch.randn(25, requires_grad=True)
        a2 = a.detach().clone().requires_grad_(True)
        b2 = b.detach().clone().requires_grad_(True)

        # A second call should hit. (First reset so in-memory guards
        # don't prevent compilation).
        self._clear_dynamo_and_codecache()
        self.assertEqual(fn(a, b), compiled_fn(a2, b2))
        fn(a, b).sum().backward()
        compiled_fn(a2, b2).sum().backward()
        self.assertEqual(a.grad, a2.grad)
        self.assertEqual(b.grad, b2.grad)

        self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
        self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

    @largeTensorTest("64GB", device=GPU_TYPE)
    @parametrize("device", (GPU_TYPE,))
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    @inductor_config.patch("fx_graph_cache", True)
    @inductor_config.patch("fx_graph_remote_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_guard_single_entry(self, device, dtype):
        """
        Test caching the same graph, but under conditions that introduce guards
        for tensor sizes < int32. See test_codecache::TestFxGraphCache::test_cache_load_with_guards_int32_bounds.

        This test in particular tests the behavior of a single entry cache. If we ever make AOTAutogradCache
        support multiple entries under the same key, this test should be updated.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires CUDA SM80 or later")

        def fn(x, y):
            return (x + x, y + y)

        def expect_miss(compiled_fn, a, b):
            self._clear_dynamo_and_codecache()
            counters.clear()
            res = compiled_fn(a, b)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                0,
            )
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            return res

        def expect_hit(compiled_fn, a, b):
            self._clear_dynamo_and_codecache()
            counters.clear()
            res = compiled_fn(a, b)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 0)
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                0,
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"],
                1,
            )
            return res

        def expect_guard_miss(compiled_fn, a, b):
            self._clear_dynamo_and_codecache()
            counters.clear()
            res = compiled_fn(a, b)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                1,
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"],
                0,
            )
            return res

        compiled_fn = torch.compile(fn, dynamic=True)

        a_shape = (5, 6)
        b_shape = (7, 8)
        a = torch.rand(a_shape, device=device, dtype=dtype)
        b = torch.rand(b_shape, device=device, dtype=dtype)
        res1 = expect_miss(compiled_fn, a, b)

        # Same shape, should cache hit
        a2 = a.detach().clone()
        b2 = b.detach().clone()

        res2 = expect_hit(compiled_fn, a2, b2)

        self.assertEqual(res1, res2)

        # By changing the shape greatly, despite the same exact input
        # graph, inductor should report a guard miss, leading
        # to a cache miss on our end.
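        # (47000 * 47001 = 2,209,047,000 elements, which exceeds the int32 max of
        # 2,147,483,647, so the size guard from the first compilation no longer holds.)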
        a_shape = (5, 6)
        b_shape = (47000, 47001)
        a3 = torch.rand(a_shape, device=device, dtype=dtype)
        b3 = torch.rand(b_shape, device=device, dtype=dtype)

        expect_guard_miss(compiled_fn, a3, b3)

        # Wobble the shape a bit, but not enough
        # to trigger a guard miss (a (6, 7) tensor is still far below the int32 bound).
        # Should result in a cache hit.
        a_shape = (6, 7)
        b_shape = (47000, 47001)
        a4 = torch.rand(a_shape, device=device, dtype=dtype)
        b4 = torch.rand(b_shape, device=device, dtype=dtype)
        expect_hit(compiled_fn, a4, b4)

        # Change the shape back to the original,
        # FXGraphCache should hit because it stores
        # multiple entries
        a_shape = (5, 6)
        b_shape = (7, 8)
        a5 = torch.rand(a_shape, device=device, dtype=dtype)
        b5 = torch.rand(b_shape, device=device, dtype=dtype)
        expect_hit(compiled_fn, a5, b5)

    @largeTensorTest("64GB", device=GPU_TYPE)
    @parametrize("device", (GPU_TYPE,))
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    @parametrize("requires_grad", (True, False))
    @inductor_config.patch("fx_graph_cache", True)
    @inductor_config.patch("fx_graph_remote_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_autograd_inductor_guards(self, device, dtype, requires_grad):
        """
        Test caching the same graph, but under conditions that introduce guards
        for tensor sizes < int32.
        See test_codecache::TestFxGraphCache::test_cache_load_with_guards_int32_bounds.
        """
        if device == GPU_TYPE and not HAS_GPU:
            raise unittest.SkipTest(f"requires {GPU_TYPE}")
        if device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater:
            raise unittest.SkipTest("requires CUDA SM80 or later")

        def fn(x, y):
            return (x + x, y + y)

        compiled_fn = torch.compile(fn, dynamic=True)

        # Iterate over different shapes, varying whether the total
        # size is below or above int32. For each combination, we expect
        # different guards around whether the symbolic sizes do or do
        # not exceed int32.
        shapes = (
            ((5, 6), (7, 8)),
            ((5, 6), (47000, 47001)),
            ((47000, 47001), (5, 6)),
        )
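        # A (47000, 47001) operand has 2,209,047,000 elements, just above the int32
        # max of 2,147,483,647, while (5, 6) and (7, 8) stay far below it.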
        expected_hits = expected_misses = expected_saves = 0
        expected_guard_misses = 0
        for a_shape, b_shape in shapes:
            a = torch.rand(
                a_shape, device=device, dtype=dtype, requires_grad=requires_grad
            )
            b = torch.rand(
                b_shape, device=device, dtype=dtype, requires_grad=requires_grad
            )

            # AVOID a dynamo reset here. We expect guards to have been
            # added that will be violated with the new shape. We should
            # see a recompilation (along with a cache miss).
            res1 = compiled_fn(a, b)
            # A first call should miss in the cache.
            expected_misses += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                expected_guard_misses,
            )

            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
            )
            # Because dynamic shapes are enabled, we expect backwards to be compiled ahead of time
            # So we should see a cache save here
            expected_saves += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
            )
            if requires_grad:
                res1[0].sum().backward()
                # No extra saves
                self.assertEqual(
                    counters["aot_autograd"]["autograd_cache_saved"], expected_saves
                )

            a2 = a.detach().clone().requires_grad_(requires_grad)
            b2 = b.detach().clone().requires_grad_(requires_grad)
            # A second call should hit. (First reset so in-memory guards
            # don't prevent compilation).

            # Now clear dynamo; we should see a cache hit.
            # This also repopulates guards in dynamo's cache, so that a subsequent
            # run with a different shape will still trigger another autograd_cache call.
            self._clear_dynamo_and_codecache()
            res2 = compiled_fn(a2, b2)
            expected_hits += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_miss"], expected_misses
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_guard_miss"],
                expected_guard_misses,
            )
            # First compile is a regular cache miss, subsequent are guard misses
            expected_guard_misses += 1
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_hit"], expected_hits
            )
            self.assertEqual(
                counters["aot_autograd"]["autograd_cache_saved"], expected_saves
            )
            self.assertEqual(res1, res2)
            if requires_grad:
                res2[0].sum().backward()
                self.assertEqual(a.grad, a2.grad)

    @inductor_config.patch("fx_graph_cache", True)
    @inductor_config.patch("fx_graph_remote_cache", False)
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_nn_module_with_params_global_constant(self):
        class MyMod(torch.nn.Module):
            CONSTANT = torch.tensor([[2, 2], [2, 2]])

            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.randn([2, 2]))

            def forward(self, x):
                return x.sin() + self.param + MyMod.CONSTANT

        with torch.no_grad():
            compiled_fn = torch.compile(MyMod(), backend="inductor", fullgraph=True)
            res1 = compiled_fn(torch.ones([2, 2]))
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 0)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

            self._clear_dynamo_and_codecache()
            res2 = compiled_fn(torch.ones([2, 2]))
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)

            self.assertEqual(res1, res2)
            # Edit the "constant". We'll still get a cache hit, but running the
            # compiled module should produce a different result, because
            # MyMod.CONSTANT is an input to the graph.
            MyMod.CONSTANT = torch.tensor([[3, 3], [3, 3]])
            self._clear_dynamo_and_codecache()
            res3 = compiled_fn(torch.ones([2, 2]))
            self.assertEqual(counters["aot_autograd"]["autograd_cache_miss"], 1)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_hit"], 2)
            self.assertEqual(counters["aot_autograd"]["autograd_cache_saved"], 1)
            self.assertNotEqual(res1, res3)
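            # CONSTANT went from all 2s to all 3s, so res3 should be exactly 1 larger
            # elementwise than res1.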
            self.assertEqual(res1, res3.sub(torch.ones(2, 2)))


@inductor_config.patch("fx_graph_cache", True)
class AOTAutogradCachePicklerTests(torch._dynamo.test_case.TestCase):
    @property
    def device_type(self) -> str:
        return "cuda" if torch.cuda.is_available() else "cpu"

    def default_config(self):
        return AOTConfig(
            fw_compiler=None,
            bw_compiler=None,
            inference_compiler=None,
            partition_fn=None,
            decompositions={},
            num_params_buffers=0,
            aot_id=0,
            keep_inference_input_mutations=False,
            dynamic_shapes=True,
            aot_autograd_arg_pos_to_source=None,
            is_export=False,
            no_tangents=False,
            enable_log=False,
        )

    def _get_dynamo_output(self, fn, *args, **kwargs):
        # Reset dynamo between runs
        torch._dynamo.reset()
        fx_graph = None
        example_inputs = None

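        # A throwaway backend that records the traced FX graph and example inputs
        # so the tests can hand them to autograd_cache_key directly.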
        def compiler(gm, inputs, **kwargs):
            nonlocal fx_graph
            nonlocal example_inputs
            fx_graph = gm
            example_inputs = inputs
            return gm

        g = torch.compile(fn, backend=compiler, fullgraph=True)
        result = g(*args, **kwargs)
        return (result, fx_graph, example_inputs)

    def gen_cache_key(self, f, config, inputs=None):
        if inputs is None:
            inputs = [torch.ones(3)]
        _, fx_g, example_inputs = self._get_dynamo_output(f, *inputs)
        return autograd_cache_key(fx_g, example_inputs, config, {})

    def test_basic_hash_key(self):
        def fn(x):
            return x.sin().cos()

        config = self.default_config()
        # Check hash is stable on multiple runs
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config)
        self.assertEqual(c1, c2)

    def test_identical_graphs_and_configs(self):
        def fn(x):
            return x.sin().cos()

        def fn2(x):
            y = x.sin()
            z = y.cos()
            return z

        # Make the id different, but otherwise identical
        config = self.default_config()
        config2 = self.default_config()
        config2.aot_id = 1

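        # aot_id is not expected to affect the cache key, so the keys should still match.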
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config2)
        self.assertEqual(c1, c2)

    def test_different_graphs(self):
        def fn(x):
            return x.cos().sin()

        def fn2(x):
            return x.sin().cos()

        config = self.default_config()
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn2, config)
        self.assertNotEqual(c1, c2)

    def test_different_configs(self):
        def fn(x):
            return x.cos().sin()

        config = self.default_config()
        config2 = self.default_config()
        config2.dynamic_shapes = False
        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config2)
        self.assertNotEqual(c1, c2)

    def test_different_inputs(self):
        def fn(x):
            return x.cos().sin()

        config = self.default_config()
        c1 = self.gen_cache_key(fn, config, inputs=[torch.ones(3)])
        c2 = self.gen_cache_key(fn, config, inputs=[torch.ones(2)])
        self.assertNotEqual(c1, c2)

    def test_different_global_configs(self):
        def fn(x):
            return x.cos().sin()

        config = self.default_config()

        c1 = self.gen_cache_key(fn, config)
        c2 = self.gen_cache_key(fn, config)
        self.assertEqual(c1, c2)

        c1 = self.gen_cache_key(fn, config)

        # Change functorch config
        with functorch_config.patch(
            {"debug_assert": not functorch_config.debug_assert}
        ):
            c2 = self.gen_cache_key(fn, config)

        self.assertNotEqual(c1, c2)

        c1 = self.gen_cache_key(fn, config)
        # Change inductor config
        with inductor_config.patch({"debug": not inductor_config.debug}):
            c2 = self.gen_cache_key(fn, config)

        self.assertNotEqual(c1, c2)

        c1 = self.gen_cache_key(fn, config)
        # Change torch grad enabled
        with torch.no_grad():
            c2 = self.gen_cache_key(fn, config)
        self.assertNotEqual(c1, c2)

    def test_incompatible_function(self):
        @torch._dynamo.allow_in_graph
        class AllowInGraphFunc(torch.autograd.Function):
            @staticmethod
            def forward(_, x):
                torch._dynamo.graph_break()
                return x.sin()

        def fn(x):
            return AllowInGraphFunc.apply(x)

        config = self.default_config()
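        # The allow_in_graph autograd.Function brings arbitrary user code into the
        # graph, which the cache presumably cannot hash safely, so key generation bypasses.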
        self.assertRaises(
            BypassAOTAutogradCache, lambda: self.gen_cache_key(fn, config)
        )

    def test_private_namespace(self):
        # TODO: anyone who monkeypatches a **public** function into torch namespace with @allow_in_graph
        # could still break our sanity check and cache something bad. But that's an edge case we'll take the risk on.
        # Monkeypatch some random private function into torch, see that it fails
        @torch._dynamo.allow_in_graph
        def my_private_fun(x):
            return x.sin()

        with patch("torch._my_priv", new=my_private_fun, create=True):

            def fn(x):
                return torch._my_priv(x)

            config = self.default_config()
            self.assertRaises(
                BypassAOTAutogradCache, lambda: self.gen_cache_key(fn, config)
            )

    def test_private_builtin(self):
        # _foreach_add is a private torch function, but
        # it's also a builtin_function_or_method, so it should be allowed to be cached
        # since dynamo allows it in the graph
        def fn(x, b):
            y = (x, x)
            return torch._foreach_add(y, b)

        config = self.default_config()
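        # The integer argument is specialized into the traced graph, so different
        # values are expected to produce different graphs and hence different keys.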
        r1 = self.gen_cache_key(fn, config, inputs=[torch.ones(3), 1])
        r2 = self.gen_cache_key(fn, config, inputs=[torch.ones(3), 2])
        self.assertNotEqual(r1, r2)

    def test_nn_module_with_params(self):
        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.seq = torch.nn.Parameter(torch.ones((3, 3)))

            def forward(self, x):
                return self.seq + x

        config = self.default_config()
        # Different inputs and parameters, but all the same size
        c1 = self.gen_cache_key(MyMod(), config, inputs=[torch.ones((3, 3))])
        c2 = self.gen_cache_key(MyMod(), config, inputs=[torch.ones((3, 3))])
        self.assertEqual(c1, c2)

    def test_normal_torch_function(self):
        @torch._dynamo.allow_in_graph
        def fn(x):
            y = torch.sin(x)
            z = torch.cos(x)
            w = y + z
            w.abs()
            return w

        config = self.default_config()
        self.gen_cache_key(fn, config)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()