# Owner(s): ["module: dynamo"]

import collections
import re
import sys
import time
from io import StringIO

import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.comptime import comptime
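# comptime runs the body of a decorated function at Dynamo trace time,
# passing a ComptimeContext that exposes the tracer's state (locals, the
# partial FX graph, guards, the value stack, and so on).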


# Because we don't support free variables in comptime at the moment,
# we have to communicate via globals.  This also means these tests cannot
# be run in parallel in a single process (not that you'd... ever want
# to do that?)
FILE = None
SELF = None


class ComptimeTests(torch._dynamo.test_case.TestCase):
    def test_print_single(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

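        # The nested function's body runs at compile time;
        # ctx.get_local("e") fetches Dynamo's compile-time view of the
        # local "e".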
        def comptime_print(e):
            @comptime
            def _(ctx):
                ctx.print(ctx.get_local("e"), file=FILE)

        Employee = collections.namedtuple("Employee", ["name", "id"])

        class mylist(list):
            pass

        @torch._dynamo.optimize(cnt, dynamic=True)
        def f(x):
            y = x * 2
            comptime_print(y)
            comptime_print(2)
            comptime_print([y, 2])
            comptime_print((y, 2))
            comptime_print({"foo": y})
            comptime_print(range(1, 3))
            comptime_print(Employee("foo", 2))
            comptime_print(mylist([1, 2]))
            comptime_print(collections.defaultdict(lambda: None))
            comptime_print(set())
            comptime_print({"a", "b"})
            comptime_print(x.size(0))
            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
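        # Under dynamic=True, tensors print as FakeTensors with symbolic
        # sizes (s0), and x.size(0) prints as the symbol itself.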
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
FakeTensor(..., size=(s0,))
2
[FakeTensor(..., size=(s0,)), 2]
(FakeTensor(..., size=(s0,)), 2)
{'foo': FakeTensor(..., size=(s0,))}
range(1, 3, 1)
Employee(name='foo', id=2)
[1, 2]
defaultdict(NestedUserFunctionVariable(), {})
set()
{'a','b'}
s0""",
        )

    def test_print_graph(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_graph(verbose=False, file=FILE)

            # Test that the compact notation doesn't error or graph break;
            # you'll have to inspect the output visually to see that it printed
            comptime.print_graph()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
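        # Only the ops traced so far appear in the graph; y + 3 has not
        # been traced yet at the point of the comptime call.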
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = y = None""",
        )

    def test_print_disas(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_disas(file=FILE)

            comptime.print_disas()

            return y + 3

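        # Normalize instruction offsets and opcode spacing so disassembly
        # output is stable across Python versions (currently unused; the
        # assertions below only check for substrings).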
        def munge_disas(s):
            return re.sub(
                r"^(?: +\d+)?(?: +(-->)) \+\d+ ([A-Za-z0-9_]+)",
                r"\1 \2",
                s,
                flags=re.MULTILINE,
            )

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        out = FILE.getvalue()
        # Check that the instruction offset is working
        self.assertIn("-->", out)
        # Check that the bytecode resembles what we expect
        self.assertIn("STORE_FAST", out)
        if sys.version_info < (3, 11):
            self.assertIn("BINARY_MULTIPLY", out)
        else:
            self.assertIn("BINARY_OP", out)

    def test_print_value_stack(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

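        # g is inlined into f; stacklevel=1 asks for the value stack one
        # frame up (f's frame), where x sits on the stack pending the
        # addition in "x + g(x)".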
        def g(x):
            @comptime
            def _(ctx):
                ctx.print_value_stack(file=FILE, stacklevel=1)

            return x

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x + g(x)

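            # The _and_return variant prints the value stack and then
            # passes y * 2 through unchanged, so it can be used inline in
            # an expression.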
            return y + comptime.print_value_stack_and_return(y * 2)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
- FakeTensor(..., size=(2,))
""",
        )

    def test_print_locals(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_locals(file=FILE)

            comptime.print_locals()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        self.assertExpectedInline(
            FILE.getvalue(),
            """\
x = FakeTensor(..., size=(2,))
y = FakeTensor(..., size=(2,))
""",
        )

    # Just make sure it doesn't crash
    def test_print_direct(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x, z):
            y = x * 2
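            # Referencing z from a lambda forces it to be a closure cell,
            # so comptime.print is exercised on a cell variable.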
            lambda: z
            comptime.print(z)
            return y + 3

        f(torch.randn(2), torch.randn(2))

    def test_sleep(self):
        sleep_time = 5
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x, z, should_sleep):
            if should_sleep:
                comptime.sleep(sleep_time)
            y = x * 2
            return y + 3

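        # comptime.sleep runs while the frame is being compiled, so the
        # second call pays the sleep once during recompilation
        # (should_sleep flipping from False to True forces a recompile).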
        start = time.time()
        f(torch.randn(2), torch.randn(2), False)
        total_no_sleep = time.time() - start

        start = time.time()
        f(torch.randn(2), torch.randn(2), True)
        total_with_sleep = time.time() - start

        self.assertTrue(total_with_sleep > sleep_time)
        # Hopefully this won't be flaky
        self.assertTrue(abs(total_with_sleep - sleep_time - total_no_sleep) < 3)

    # Just make sure it doesn't crash
    def test_get_local_closure_variable(self):
        global SELF
        SELF = self
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            z = 3

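            # z is a free variable inside g; get_local should still be
            # able to resolve it from the enclosing frame.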
            def g():
                @comptime
                def _(ctx):
                    r = ctx.get_local("z")
                    SELF.assertEqual(repr(r), "3")

                comptime.print(z)
                return 2

            y = x * g()
            return y + 3

        f(torch.randn(2))

    def test_print_bt(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        def g(x):
            @comptime
            def _(ctx):
                ctx.print_bt(file=FILE)

            comptime.print_bt()

            return x + 3

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            y = g(y)
            return y + 3

        def munge_filenames(s):
            return re.sub(r'File "[^"]+", line \d+', 'File "X", line X', s)

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
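        # The compile-time backtrace should include the inlined call site
        # in f.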
        bt = FILE.getvalue()
        self.assertIn("y = g(y)", bt)

    def test_print_guards(self):
        global FILE
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.print_guards(file=FILE)

            comptime.print_guards()

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
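        # Payload fields print as None in the compile-time guard dump;
        # trailing whitespace is stripped so the expected output stays
        # stable.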
        self.assertExpectedInline(
            re.sub(r"\s+$", "", FILE.getvalue().rstrip(), flags=re.MULTILINE),
            """\

        local "L['x']" TENSOR_MATCH
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' GRAD_MODE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DETERMINISTIC_ALGORITHMS
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' TORCH_FUNCTION_STATE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        global '' DEFAULT_DEVICE
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }
        shape_env '' SHAPE_ENV
        {
            'guard_types': None,
            'code': None,
            'obj_weakref': None
            'guarded_class': None
        }""",
        )

    def test_graph_break(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2

            @comptime
            def _(ctx):
                pass

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
        cnt.frame_count = 0

        @torch._dynamo.optimize(cnt)
        def g(x):
            y = x * 2

            @comptime
            def _(ctx):
                ctx.graph_break()

            y = y + 2

            comptime.graph_break()

            return y * 3

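        # Two graph breaks split g into three compiled frames: the code
        # before the first break, between the breaks, and after the second.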
        g(torch.randn(2))
        self.assertEqual(cnt.frame_count, 3)

    def test_get_local(self):
        global SELF, FILE
        SELF = self
        FILE = StringIO()
        cnt = torch._dynamo.testing.CompileCounter()

        @torch._dynamo.optimize(cnt)
        def f(x):
            y = x * 2
            lit = 2

            @comptime
            def _(ctx):
                y = ctx.get_local("y")
                SELF.assertEqual(y.as_fake().size(0), 2)
                SELF.assertEqual(y.size(0), 2)
                # Trigger a graph write (TODO: this is not so
                # useful right now as there's no way to make use
                # of the output proxy; maybe it's useful for inserting
                # side-effectful operations into the graph)
                y.as_proxy() + 4
                ctx.print_graph(verbose=False, file=FILE)
                SELF.assertIs(y.python_type(), torch.Tensor)
                lit = ctx.get_local("lit")
                SELF.assertEqual(lit.as_python_constant(), 2)

            return y + 3

        f(torch.randn(2))
        self.assertEqual(cnt.frame_count, 1)
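        # The proxy arithmetic above shows up in the traced graph as a
        # dangling "add" node.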
        self.assertExpectedInline(
            FILE.getvalue().strip(),
            """\
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    y = l_x_ * 2;  l_x_ = None
    add = y + 4;  y = add = None""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()