xref: /aosp_15_r20/external/pytorch/test/dynamo/test_bytecode_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2
3import collections
4import dis
5import sys
6import unittest
7
8import torch
9import torch._dynamo.test_case
10from torch._dynamo import bytecode_analysis, bytecode_transformation
11from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312
12
13
14class BytecodeTests(torch._dynamo.test_case.TestCase):
15    @skipIfNotPy311
16    def test_linetable_311_writer1(self):
17        def fn():
18            a = 10
19            b = 20
20            # prevent LOAD_FAST_LOAD_FAST in 3.13 by wrapping b with g()
21            c = a + g(b)
22            f = "linetable_writer"
23            return f"Test if {f} generates correct co_linetable: {c}"
24
25        keys = bytecode_transformation.get_code_keys()
26        code_options = {k: getattr(fn.__code__, k) for k in keys}
27        result = bytecode_transformation.clean_and_assemble_instructions(
28            bytecode_transformation.cleaned_instructions(fn.__code__),
29            keys,
30            code_options,
31        )
32        l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
33        self.assertEqual(len(l1), len(l2))
34        for p1, p2 in zip(l1, l2):
35            self.assertEqual(p1, p2)
36        # TODO co_lnotab is deprecated in 3.12 and will be removed in 3.14
37        # In 3.11+,. it is computed lazily from other linetable attributes (e.g. co_linetable),
38        # so we do not set this attribute ourselves.
39        self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
40
41    @skipIfNotPy311
42    def test_linetable_311_writer2(self):
43        """
44        test large ops (LOAD_METHOD) and EXTENDED_ARGS
45        fn_str is in the form:
46        def fn():
47            ...
48            x0 = 1
49            x1 = 1
50            ...
51            l = [x0, x1, ...]
52        """
53        fn_str = f"""\
54def fn():
55    foo.bar(1, 2, 3)
56{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
57    l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}]
58        """
59        locals = {}
60        exec(fn_str, {}, locals)
61        fn = locals["fn"]
62        orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn))))
63        self.assertIn("EXTENDED_ARG", orig_inst_str)
64        load_method_str = "LOAD_ATTR" if sys.version_info >= (3, 12) else "LOAD_METHOD"
65        self.assertIn(load_method_str, orig_inst_str)
66        keys = bytecode_transformation.get_code_keys()
67        code_options = {k: getattr(fn.__code__, k) for k in keys}
68        result = bytecode_transformation.clean_and_assemble_instructions(
69            bytecode_transformation.cleaned_instructions(fn.__code__),
70            keys,
71            code_options,
72        )
73        new_inst_str = "\n".join(list(map(str, result[0])))
74        self.assertIn("EXTENDED_ARG", new_inst_str)
75        self.assertIn(load_method_str, new_inst_str)
76        l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
77        self.assertEqual(len(l1), len(l2))
78        for p1, p2 in zip(l1, l2):
79            self.assertEqual(p1, p2)
80        self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
81
82    @unittest.skipIf(
83        sys.version_info < (3, 10) or sys.version_info >= (3, 11),
84        "linetable test for Python 3.10",
85    )
86    def test_linetable_310_writer(self):
87        def fn():
88            a = 10
89            b = 20
90            c = a + b
91            f = "linetable_writer"
92            return f"Test if {f} generates correct co_linetable: {c}"
93
94        inst = dis.get_instructions(fn)
95        result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
96        self.assertTrue(result[1] == fn.__code__.co_linetable)
97
98    @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10")
99    def test_lnotab_writer(self):
100        def fn():
101            a = 10
102            b = 20
103            c = a + b
104            f = "lnotab_writer"
105            return f"Test if {f} generates correct co_lnotab: {c}"
106
107        inst = dis.get_instructions(fn)
108        result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
109        self.assertTrue(result[1] == fn.__code__.co_lnotab)
110
111    def test_if_tensor_is_none(self):
112        """
113        Python 3.11 adds new jump instructions that check if
114        TOS is None. We do not support these instructions.
115        """
116
117        def f(x, y):
118            z = 1
119            if x is None:
120                z *= 2
121            if y is not None:
122                z *= 3
123            return z
124
125        opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
126        self.assertEqual(opt_f(None, torch.ones(2)), 6)
127
128        if sys.version_info >= (3, 11):
129            insts = bytecode_transformation.cleaned_instructions(f.__code__)
130            for inst in insts:
131                self.assertNotIn("_NONE", inst.opname)
132
133    @skipIfNotPy311
134    def test_py311_jump_offset(self):
135        new_inst = bytecode_transformation.create_instruction
136        consts = (None, 1, 2, 3, 4)
137
138        def create_test_code(jump_opname, target_idx):
139            targets = [
140                new_inst("LOAD_CONST", argval=1),
141                new_inst("LOAD_CONST", argval=3),
142            ]
143            jump_to_target_inst = new_inst(jump_opname, target=targets[target_idx])
144            """
145            pseudocode of generated bytecode:
146            def test_py311_fn():
147                goto target1
148            target0:
149                return 1
150            target1:
151                goto [target0/target2] (via fwd or bwd jump)
152                return 2
153            target2:
154                return 3
155                return 4
156            """
157            # test with LOAD_GLOBAL since it has a different instruction size
158            insts = [
159                new_inst("RESUME", arg=0),
160                new_inst("JUMP_FORWARD", target=jump_to_target_inst),
161                targets[0],
162                new_inst("LOAD_GLOBAL", arg=0, argval="print"),
163                new_inst("POP_TOP"),
164                new_inst("RETURN_VALUE"),
165                jump_to_target_inst,
166                new_inst("LOAD_CONST", argval=2),
167                new_inst("LOAD_GLOBAL", arg=0, argval="print"),
168                new_inst("POP_TOP"),
169                new_inst("RETURN_VALUE"),
170                targets[1],
171                new_inst("RETURN_VALUE"),
172                new_inst("LOAD_CONST", argval=4),
173                new_inst("RETURN_VALUE"),
174            ]
175            code_options = collections.OrderedDict(
176                [
177                    ("co_argcount", 0),
178                    ("co_posonlyargcount", 0),
179                    ("co_kwonlyargcount", 0),
180                    ("co_nlocals", 0),
181                    ("co_stacksize", 2),
182                    ("co_flags", 3),
183                    ("co_code", b""),
184                    ("co_consts", consts),
185                    ("co_names", ("print",)),
186                    ("co_varnames", ()),
187                    ("co_filename", __file__),
188                    ("co_name", "test_py311_fn"),
189                    ("co_qualname", "test_py311_fn"),
190                    ("co_firstlineno", 1),
191                    ("co_linetable", b""),
192                    ("co_exceptiontable", b""),
193                    ("co_freevars", ()),
194                    ("co_cellvars", ()),
195                ]
196            )
197            return bytecode_transformation.clean_and_assemble_instructions(
198                insts,
199                list(code_options.keys()),
200                code_options,
201            )
202
203        # format: jump_opname, target_idx, expected forward jump, expected return value
204        test_args = (
205            ("JUMP_FORWARD", 0, False, 1),
206            ("JUMP_FORWARD", 1, True, 3),
207            ("JUMP_BACKWARD", 0, False, 1),
208            ("JUMP_BACKWARD", 1, True, 3),
209        )
210
211        for test in test_args:
212            insts, code = create_test_code(test[0], test[1])
213            # check if offset of latest jump instruction is forward/backward
214            for inst in reversed(insts):
215                if inst.opname.startswith("JUMP"):
216                    if test[2]:
217                        self.assertIn("FORWARD", inst.opname)
218                    else:
219                        self.assertIn("BACKWARD", inst.opname)
220                    break
221            # run the code and check result
222
223            def dummy_fn():
224                pass
225
226            dummy_fn.__code__ = code
227            self.assertEqual(dummy_fn(), test[3])
228
229            dummy_opt = torch._dynamo.optimize("eager")(dummy_fn)
230            self.assertEqual(dummy_opt(), test[3])
231
232    def test_exception_table_encode_varint(self):
233        # these numbers have no real meaning to them
234        nums = [
235            0b111_101010_000000,
236            0b1100_111000_010101_101010,
237        ]
238        b = bytecode_transformation.encode_exception_table_varint(
239            nums[0]
240        ) + bytecode_transformation.encode_exception_table_varint(nums[1])
241        nums_new = []
242        b_iter = iter(bytes(b))
243        while True:
244            try:
245                nums_new.append(
246                    bytecode_transformation.decode_exception_table_varint(b_iter)
247                )
248            except StopIteration:
249                break
250        self.assertEqual(nums, nums_new)
251
252    @skipIfNotPy311
253    def test_exception_table_parsing(self):
254        def fn():
255            try:
256                with a():
257                    b()
258                c()
259            except Exception:
260                d()
261            finally:
262                e()
263            f()
264
265        tab = bytecode_transformation.parse_exception_table(
266            fn.__code__.co_exceptiontable
267        )
268        b = bytecode_transformation.assemble_exception_table(tab)
269        self.assertEqual(b, fn.__code__.co_exceptiontable)
270
271    @skipIfNotPy311
272    def test_exception_table_e2e(self):
273        def fn():
274            try:
275                with a():
276                    b()
277                c()
278            except Exception:
279                d()
280            finally:
281                e()
282            f()
283
284        def nothing(*args):
285            pass
286
287        code = bytecode_transformation.transform_code_object(fn.__code__, nothing)
288        self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable)
289
290    @skipIfNotPy311
291    def test_exception_table_e2e_2(self):
292        # last instructions of an exn_table entry is a large instruction
293        # i.e., LOAD_GLOBAL a
294        def fn():
295            try:
296                return a
297            except Exception:
298                pass
299
300        def nothing(*args):
301            pass
302
303        code = bytecode_transformation.transform_code_object(fn.__code__, nothing)
304        self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable)
305
306    @skipIfNotPy311
307    def test_exception_table_entry_propagation(self):
308        insts = []
309        for _ in range(10):
310            insts.append(bytecode_transformation.create_instruction("NOP"))
311        insts[8].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
312            insts[0], insts[9], insts[0], 0, True
313        )
314        insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
315            insts[0], insts[0], insts[1], 0, True
316        )
317        insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
318            insts[0], insts[2], insts[2], 0, True
319        )
320        insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
321            insts[4], insts[6], insts[3], 0, True
322        )
323        insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
324            insts[9], insts[9], insts[4], 0, True
325        )
326        insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
327            insts[7], insts[9], insts[5], 0, True
328        )
329        bytecode_transformation.propagate_inst_exn_table_entries(insts)
330        expected = [1, 2, 2, 0, 3, 3, 3, 5, 5, 4]
331        for inst, exp in zip(insts, expected):
332            self.assertIsNotNone(inst.exn_tab_entry)
333            self.assertIs(inst.exn_tab_entry.target, insts[exp])
334
335    @skipIfNotPy311
336    def test_compute_exception_table_nested(self):
337        insts = []
338        for _ in range(20):
339            insts.append(bytecode_transformation.create_instruction("NOP"))
340        insts[10].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
341            insts[1], insts[10], insts[0], 0, True
342        )
343        insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
344            insts[1], insts[1], insts[1], 0, True
345        )
346        insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
347            insts[1], insts[3], insts[2], 0, True
348        )
349        insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
350            insts[5], insts[7], insts[3], 0, True
351        )
352        insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
353            insts[10], insts[10], insts[4], 0, True
354        )
355        insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
356            insts[8], insts[10], insts[5], 0, True
357        )
358        insts[14].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
359            insts[13], insts[17], insts[6], 0, True
360        )
361        insts[16].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
362            insts[15], insts[16], insts[7], 0, True
363        )
364        bytecode_transformation.update_offsets(insts)
365        tab = bytecode_transformation.compute_exception_table(insts)
366        expected = [
367            (1, 1, 1),
368            (2, 3, 2),
369            (4, 4, 0),
370            (5, 7, 3),
371            (8, 9, 5),
372            (10, 10, 4),
373            (13, 14, 6),
374            (15, 16, 7),
375            (17, 17, 6),
376        ]
377        self.assertEqual(len(tab), len(expected))
378        for entry, exp in zip(tab, expected):
379            self.assertEqual(entry.start, exp[0] * 2)
380            self.assertEqual(entry.end, exp[1] * 2)
381            self.assertEqual(entry.target, exp[2] * 2)
382
383    @skipIfNotPy311
384    def test_remove_dead_code_with_exn_table_entries(self):
385        create_instruction = bytecode_transformation.create_instruction
386        target1 = create_instruction("NOP")
387        target2 = create_instruction("NOP")
388        target3 = create_instruction("NOP")
389        exn_start = create_instruction("NOP")
390        exn_end = create_instruction("NOP")
391        insts = [
392            create_instruction("JUMP_FORWARD", target=target1),
393            exn_start,  # dead
394            target1,
395            create_instruction("JUMP_FORWARD", target=target3),
396            exn_end,  # dead
397            target2,
398            target3,
399        ]
400        exn_start.exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
401            exn_start, exn_end, target2, 0, True
402        )
403        bytecode_transformation.propagate_inst_exn_table_entries(insts)
404        insts = bytecode_analysis.remove_dead_code(insts)
405        self.assertEqual(len(insts), 5)
406        self.assertNotIn(exn_start, insts)
407        self.assertNotIn(exn_end, insts)
408        self.assertIn(target2, insts)
409        self.assertIn(target3, insts)
410        bytecode_transformation.update_offsets(insts)
411        tab = bytecode_transformation.compute_exception_table(insts)
412        self.assertEqual(len(tab), 1)
413        self.assertEqual(tab[0].start, 2)
414        self.assertEqual(tab[0].end, 4)
415        self.assertEqual(tab[0].target, 6)
416
417    def test_bytecode_from_template(self):
418        def fn(d1):
419            for k, v in d1.items():
420                d2[k] = v
421
422        varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"}
423        insts = bytecode_transformation.bytecode_from_template(fn, varname_map)
424        for inst in insts:
425            self.assertIsNone(inst.starts_line)
426            if inst.opname.startswith("LOAD"):
427                self.assertNotIn(inst.argval, varname_map)
428                if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"):
429                    self.assertIsNone(inst.arg)
430            self.assertFalse(inst.opname.startswith("RETURN"))
431
432    @skipIfNotPy311
433    def test_bytecode_from_template_noprefix(self):
434        # Test that 3.11+ prefix instructions are removed
435        def gen_fn():
436            cl = None
437
438            def fn():
439                return cl
440
441            return fn
442
443        fn = gen_fn()
444
445        dis_insts = list(dis.get_instructions(fn))
446        names = {inst.opname for inst in dis_insts}
447        self.assertIn("RESUME", names)
448        self.assertIn("COPY_FREE_VARS", names)
449
450        insts = bytecode_transformation.bytecode_from_template(fn)
451        names = {inst.opname for inst in insts}
452        self.assertNotIn("RESUME", names)
453        self.assertNotIn("COPY_FREE_VARS", names)
454
455    def test_bytecode_from_template_noreturn1(self):
456        # Test that functions with multiple returns will have their
457        # returns replaced with jumps to the end
458        def fn():
459            if x:
460                return y
461            z = 3
462            return z
463
464        dis_insts = list(dis.get_instructions(fn))
465        dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts))
466        self.assertGreater(len(dis_returns), 1)
467        self.assertTrue(dis_insts[-1].opname.startswith("RETURN"))
468
469        insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
470        self.assertEqual(insts[-1].opname, "NOP")
471        self.assertEqual(len(dis_insts), len(insts))
472        for i0, i1 in zip(dis_insts, insts):
473            if i0.opname.startswith("RETURN"):
474                if i1 is insts[-1]:
475                    continue
476                self.assertIn("JUMP", i1.opname)
477                self.assertIs(i1.target, insts[-1])
478
479    # Should work with 3.10, but testing with 3.11+ is sufficient.
480    # In 3.8, `fn` ends with a RETURN_VALUE.
481    @skipIfNotPy311
482    def test_bytecode_from_template_noreturn2(self):
483        # Test function that doesn't end with RETURN_VALUE
484        def fn():
485            if x:
486                return x
487            if x:
488                return x
489            raise RuntimeError
490
491        dis_insts = list(dis.get_instructions(fn))
492        self.assertFalse(dis_insts[-1].opname.startswith("RETURN"))
493
494        insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
495        self.assertEqual(insts[-1].opname, "NOP")
496        self.assertEqual(insts[-2].opname, dis_insts[-1].opname)
497        self.assertEqual(len(dis_insts) + 1, len(insts))
498        for i0, i1 in zip(dis_insts, insts):
499            if i0.opname.startswith("RETURN"):
500                self.assertIn("JUMP", i1.opname)
501                self.assertIs(i1.target, insts[-1])
502
503    @skipIfNotPy312
504    def test_bytecode_from_template_noreturn_const(self):
505        # Test 3.12+ RETURN_CONST
506        def fn():
507            if x:
508                return 1
509            return 0
510
511        dis_insts = list(dis.get_instructions(fn))
512        dis_return_consts = list(
513            filter(lambda x: x.opname == "RETURN_CONST", dis_insts)
514        )
515        self.assertGreater(len(dis_return_consts), 1)
516        self.assertTrue(dis_insts[-1].opname == "RETURN_CONST")
517
518        insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False)
519        self.assertEqual(insts[-1].opname, "NOP")
520        insts_i = 0
521        for i, inst in enumerate(dis_insts):
522            if inst.opname == "RETURN_CONST":
523                self.assertEqual(insts[insts_i].opname, "LOAD_CONST")
524                insts_i += 1
525                if insts_i != len(insts) - 1:
526                    self.assertIn("JUMP", insts[insts_i].opname)
527                    self.assertIs(insts[insts_i].target, insts[-1])
528            insts_i += 1
529
530
531class BytecodeHookTests(torch._dynamo.test_case.TestCase):
532    def test_bytecode_hook(self):
533        def fn(a, b):
534            return a - b * 10
535
536        def hook(code, out_code):
537            print(code)
538            print(out_code)
539            return code
540
541        torch._dynamo.reset()
542        handle = torch._dynamo.convert_frame.register_bytecode_hook(hook)
543        try:
544            opt_fn = torch.compile(fn)
545            for i in range(2, 12):
546                opt_fn(torch.randn(i), torch.randn(i))
547        finally:
548            handle.remove()
549
550
551if __name__ == "__main__":
552    from torch._dynamo.test_case import run_tests
553
554    run_tests()
555