# Owner(s): ["module: dynamo"]
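# Tests for the user experience around Dynamo recompilation: behavior when
# config.cache_size_limit is hit, the warning emitted at that point, and the
# detail/accuracy of guard-failure messages reported on recompiles.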
import unittest
import weakref

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._logging
from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings


class RecompileUxTests(torch._dynamo.test_case.TestCase):
    # TODO(whc) dynamo actually recompiles one more time than the cache limit
    cache_limit = 1

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch("cache_size_limit", cls.cache_limit)
        )

    def test_drop_cache_on_skip(self):
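        # The first call compiles and attaches a finalizer to the compiled
        # forward.  The second call fails the guard on the int `i`; with
        # cache_limit == 1, Dynamo is expected to drop the cached entry rather
        # than recompile (the backend asserts it only runs once), which releases
        # the compiled forward and fires the finalizer.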
        def model(x, i):
            return x + i

        attached = False
        triggered = False

        def trigger():
            nonlocal triggered
            triggered = True

        def compiler(gm, input):
            nonlocal attached
            f = gm.forward
            assert not attached
            # NB: making this a weakref.ref causes the cycle to no
            # longer be promptly GC'ed
            weakref.finalize(f, trigger)
            attached = True
            return f

        x = torch.randn(2)
        for i in range(2):
            opt_model = torch._dynamo.optimize(compiler)(model)
            opt_model(x, i)

        self.assertTrue(triggered)

    def test_loop_torture(self):
        def loop_torture(input, iters):
            out = input
            # randint itself causes one graph break
            for _ in range(iters):
                out += input
            return out

        compile_counter = torch._dynamo.testing.CompileCounter()
        for _ in range(10):
            x = torch.randn(3)
            iters = torch.randint(low=0, high=1000, size=())
            opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture)
            opt_loop_torture(x, iters)

        # Currently, we recompile each time;
        # we'd probably like to bail out quickly and warn.
        # TODO(whc) these checks fail on py37.  Why?
        # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit)
        # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit)

        # compile_counter only sees frames that were fed to the backend compiler,
        # which is a subset of counters["frames"]["ok"] -- probably because
        # counters["frames"]["ok"] includes frames not containing torch ops?
        self.assertEqual(compile_counter.frame_count, self.cache_limit)

    @torch._dynamo.config.patch("automatic_dynamic_shapes", False)
    def test_dynamic_input(self):
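        # With automatic_dynamic_shapes disabled, each new batch size fails the
        # static shape guard, so the frame keeps recompiling until it hits
        # cache_size_limit, at which point a single warning should be logged.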
        def model(input):
            return input + input

        expected_recompiles = 2
        compile_counter = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch("cache_size_limit", expected_recompiles):
            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                for _ in range(10):
                    bsz = torch.randint(low=0, high=1000, size=())
                    x = torch.randn((bsz, 3, 4))
                    opt_model = torch._dynamo.optimize(compile_counter)(model)
                    opt_model(x)

        self.assertEqual(compile_counter.frame_count, expected_recompiles)
        self.assertEqual(len(logs.records), 1)
        print(logs.records[0])
        self.assertTrue(
            logs.records[0]
            .getMessage()
            .startswith("torch._dynamo hit config.cache_size_limit")
        )

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_nvfuser_guards(self):
        # We may want to model dynamo's guards closely enough on nvfuser's
        # ProfilingExecutor guards that dynamo stays in charge of all
        # recompilations at the top level, which would let us simplify the
        # underlying torchscript executor.
        def func(a, b, c):
            return a + b * c

        a = torch.rand(3, 4, 5, device="cuda")
        b = torch.rand(3, 4, 5, device="cuda")
        b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5)
        b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1)
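        # b_v ends up with the same sizes and contiguous strides as b, while b_p
        # has the same sizes but permuted (non-contiguous) strides, so only b_p
        # is expected to trigger a recompile below.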
        c = torch.rand(3, 4, 5, device="cuda")
        compile_counter = torch._dynamo.testing.CompileCounter()

        with torch._dynamo.config.patch("cache_size_limit", 2):
            opt_func = torch._dynamo.optimize(compile_counter)(func)
            opt_func(a, b, c)  # warmup
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b, c)  # no guard fail or recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_v, c)  # a view should not cause nvfuser recompile
            self.assertEqual(compile_counter.frame_count, 1)

            opt_func(a, b_p, c)  # a permutation should cause recompile
            self.assertEqual(compile_counter.frame_count, 2)

    def assert_single_log_contains(self, logs, contains_str):
        self.assertEqual(len(logs.records), 1)
        self.assertTrue(
            logs.records[0].getMessage().find(contains_str) > 0,
            msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
        )

    def test_verbose_tensor_check(self):
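        # Each case below warms the cache with one input, then calls with a
        # mismatched input and checks that the recompile warning pinpoints the
        # exact guard that failed (size, stride, rank, dispatch keys, dtype,
        # requires_grad).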
        def func(a):
            # Warning: choose a function here whose meta implementation lives
            # entirely in C++.  If you do a Python one, Dynamo will dive into
            # torch._refs, which is OK but will muddy up the warnings
            return torch.add(a, 4)

        def cache_fail_test(cached_input, missed_input, expected_failure):
            # TODO(whc) maybe it's hacky to have a 'test within a test', but this seemed convenient
            torch._dynamo.reset()
            torch._dynamo.utils.counters.clear()
            opt_func = torch._dynamo.optimize("eager")(func)
            # warmup
            opt_func(cached_input)

            with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
                opt_func = torch._dynamo.optimize("eager")(func)
                opt_func(missed_input)
            self.assert_single_log_contains(logs, expected_failure)

        a = torch.rand(3, 4, 5)
        cache_fail_test(
            a,
            a[0:2, :, :],
            "tensor 'L['a']' size mismatch at index 0. expected 3, actual 2",
        )
        cache_fail_test(
            a,
            a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)),
            "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1",
        )
        cache_fail_test(
            a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2"
        )
        cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.")
        cache_fail_test(
            a,
            a.to(torch.float16),
            "tensor 'L['a']' dtype mismatch. expected Float, actual Half",
        )
        a_grad = a.clone()
        a_grad.requires_grad = True
        cache_fail_test(
            a,
            a_grad,
            "tensor 'L['a']' requires_grad mismatch. expected requires_grad=0",
        )

    def test_mismatched_type(self):
        a = torch.rand(3, 4, 5)
        b = torch.rand(3, 4, 5)

        def func(a, b):
            return a + b

        opt_func = torch._dynamo.optimize("eager")(func)
        # warmup
        opt_func(a, b)

        with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs:
            opt_func = torch._dynamo.optimize("eager")(func)
            opt_func(a, 1)
        self.assert_single_log_contains(
            logs,
            "expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>",
        )

    @torch._dynamo.config.patch("cache_size_limit", 32)
    def test_multiple_guard_fails(self):
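        # With dynamic=False, every distinct input length installs its own cache
        # entry; on the final call (size 12) the guard_fail_fn should report one
        # size-mismatch reason for each previously cached size (8 through 11).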
        failure_reasons = []

        def guard_fail_fn(failure):
            failure_reasons.append(failure[0])

        def f(x):
            return torch.relu(x)

        opt_f = torch._dynamo.optimize(
            backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
        )(f)

        for i in range(5):
            failure_reasons.clear()
            opt_f(torch.randn(8 + i))

        failure_str = "\n".join(failure_reasons)
        for line in """\
tensor 'L['x']' size mismatch at index 0. expected 11, actual 12
tensor 'L['x']' size mismatch at index 0. expected 10, actual 12
tensor 'L['x']' size mismatch at index 0. expected 9, actual 12
tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split(
            "\n"
        ):
            self.assertIn(
                line,
                failure_str,
            )

    @torch._dynamo.config.patch("cache_size_limit", 32)
    def test_multiple_guard_fails_report_all(self):
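        # With recompiles_verbose logging enabled, each recompile should report
        # the guard failures for every cached entry, not just the first one; the
        # assertions below check that reasons from all prior entries show up.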
        with log_settings(kwargs_to_settings(recompiles_verbose=True)):
            failure_reasons = []

            def guard_fail_fn(failure):
                failure_reasons.append(failure[0])

            def f(x):
                return torch.ones(len(x), x[-1])

            opt_f = torch._dynamo.optimize(
                backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False
            )(f)

            opt_f([4, 5, 6])

            def filter_reasons():
                return "\n".join(
                    [
                        line
                        for line in "\n".join(failure_reasons).splitlines()
                        if not line.startswith("___check_type_id")
                    ]
                )

            failure_reasons.clear()
            opt_f([7, 8])

            for line in """\
len(L['x']) == 3""".split(
                "\n"
            ):
                self.assertIn(line, filter_reasons())

            failure_reasons.clear()
            opt_f([9])

            for line in """\
len(L['x']) == 2
len(L['x']) == 3""".split(
                "\n"
            ):
                self.assertIn(line, filter_reasons())


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()