1# Owner(s): ["module: dynamo"] 2 3import collections 4import dis 5import sys 6import unittest 7 8import torch 9import torch._dynamo.test_case 10from torch._dynamo import bytecode_analysis, bytecode_transformation 11from torch._dynamo.testing import skipIfNotPy311, skipIfNotPy312 12 13 14class BytecodeTests(torch._dynamo.test_case.TestCase): 15 @skipIfNotPy311 16 def test_linetable_311_writer1(self): 17 def fn(): 18 a = 10 19 b = 20 20 # prevent LOAD_FAST_LOAD_FAST in 3.13 by wrapping b with g() 21 c = a + g(b) 22 f = "linetable_writer" 23 return f"Test if {f} generates correct co_linetable: {c}" 24 25 keys = bytecode_transformation.get_code_keys() 26 code_options = {k: getattr(fn.__code__, k) for k in keys} 27 result = bytecode_transformation.clean_and_assemble_instructions( 28 bytecode_transformation.cleaned_instructions(fn.__code__), 29 keys, 30 code_options, 31 ) 32 l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions()) 33 self.assertEqual(len(l1), len(l2)) 34 for p1, p2 in zip(l1, l2): 35 self.assertEqual(p1, p2) 36 # TODO co_lnotab is deprecated in 3.12 and will be removed in 3.14 37 # In 3.11+,. it is computed lazily from other linetable attributes (e.g. co_linetable), 38 # so we do not set this attribute ourselves. 39 self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) 40 41 @skipIfNotPy311 42 def test_linetable_311_writer2(self): 43 """ 44 test large ops (LOAD_METHOD) and EXTENDED_ARGS 45 fn_str is in the form: 46 def fn(): 47 ... 48 x0 = 1 49 x1 = 1 50 ... 51 l = [x0, x1, ...] 52 """ 53 fn_str = f"""\ 54def fn(): 55 foo.bar(1, 2, 3) 56{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} 57 l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] 58 """ 59 locals = {} 60 exec(fn_str, {}, locals) 61 fn = locals["fn"] 62 orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn)))) 63 self.assertIn("EXTENDED_ARG", orig_inst_str) 64 load_method_str = "LOAD_ATTR" if sys.version_info >= (3, 12) else "LOAD_METHOD" 65 self.assertIn(load_method_str, orig_inst_str) 66 keys = bytecode_transformation.get_code_keys() 67 code_options = {k: getattr(fn.__code__, k) for k in keys} 68 result = bytecode_transformation.clean_and_assemble_instructions( 69 bytecode_transformation.cleaned_instructions(fn.__code__), 70 keys, 71 code_options, 72 ) 73 new_inst_str = "\n".join(list(map(str, result[0]))) 74 self.assertIn("EXTENDED_ARG", new_inst_str) 75 self.assertIn(load_method_str, new_inst_str) 76 l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions()) 77 self.assertEqual(len(l1), len(l2)) 78 for p1, p2 in zip(l1, l2): 79 self.assertEqual(p1, p2) 80 self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab) 81 82 @unittest.skipIf( 83 sys.version_info < (3, 10) or sys.version_info >= (3, 11), 84 "linetable test for Python 3.10", 85 ) 86 def test_linetable_310_writer(self): 87 def fn(): 88 a = 10 89 b = 20 90 c = a + b 91 f = "linetable_writer" 92 return f"Test if {f} generates correct co_linetable: {c}" 93 94 inst = dis.get_instructions(fn) 95 result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) 96 self.assertTrue(result[1] == fn.__code__.co_linetable) 97 98 @unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10") 99 def test_lnotab_writer(self): 100 def fn(): 101 a = 10 102 b = 20 103 c = a + b 104 f = "lnotab_writer" 105 return f"Test if {f} generates correct co_lnotab: {c}" 106 107 inst = dis.get_instructions(fn) 108 result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno) 109 self.assertTrue(result[1] == fn.__code__.co_lnotab) 110 111 def test_if_tensor_is_none(self): 112 """ 113 Python 3.11 adds new jump instructions that check if 114 TOS is None. We do not support these instructions. 115 """ 116 117 def f(x, y): 118 z = 1 119 if x is None: 120 z *= 2 121 if y is not None: 122 z *= 3 123 return z 124 125 opt_f = torch._dynamo.optimize("eager", nopython=True)(f) 126 self.assertEqual(opt_f(None, torch.ones(2)), 6) 127 128 if sys.version_info >= (3, 11): 129 insts = bytecode_transformation.cleaned_instructions(f.__code__) 130 for inst in insts: 131 self.assertNotIn("_NONE", inst.opname) 132 133 @skipIfNotPy311 134 def test_py311_jump_offset(self): 135 new_inst = bytecode_transformation.create_instruction 136 consts = (None, 1, 2, 3, 4) 137 138 def create_test_code(jump_opname, target_idx): 139 targets = [ 140 new_inst("LOAD_CONST", argval=1), 141 new_inst("LOAD_CONST", argval=3), 142 ] 143 jump_to_target_inst = new_inst(jump_opname, target=targets[target_idx]) 144 """ 145 pseudocode of generated bytecode: 146 def test_py311_fn(): 147 goto target1 148 target0: 149 return 1 150 target1: 151 goto [target0/target2] (via fwd or bwd jump) 152 return 2 153 target2: 154 return 3 155 return 4 156 """ 157 # test with LOAD_GLOBAL since it has a different instruction size 158 insts = [ 159 new_inst("RESUME", arg=0), 160 new_inst("JUMP_FORWARD", target=jump_to_target_inst), 161 targets[0], 162 new_inst("LOAD_GLOBAL", arg=0, argval="print"), 163 new_inst("POP_TOP"), 164 new_inst("RETURN_VALUE"), 165 jump_to_target_inst, 166 new_inst("LOAD_CONST", argval=2), 167 new_inst("LOAD_GLOBAL", arg=0, argval="print"), 168 new_inst("POP_TOP"), 169 new_inst("RETURN_VALUE"), 170 targets[1], 171 new_inst("RETURN_VALUE"), 172 new_inst("LOAD_CONST", argval=4), 173 new_inst("RETURN_VALUE"), 174 ] 175 code_options = collections.OrderedDict( 176 [ 177 ("co_argcount", 0), 178 ("co_posonlyargcount", 0), 179 ("co_kwonlyargcount", 0), 180 ("co_nlocals", 0), 181 ("co_stacksize", 2), 182 ("co_flags", 3), 183 ("co_code", b""), 184 ("co_consts", consts), 185 ("co_names", ("print",)), 186 ("co_varnames", ()), 187 ("co_filename", __file__), 188 ("co_name", "test_py311_fn"), 189 ("co_qualname", "test_py311_fn"), 190 ("co_firstlineno", 1), 191 ("co_linetable", b""), 192 ("co_exceptiontable", b""), 193 ("co_freevars", ()), 194 ("co_cellvars", ()), 195 ] 196 ) 197 return bytecode_transformation.clean_and_assemble_instructions( 198 insts, 199 list(code_options.keys()), 200 code_options, 201 ) 202 203 # format: jump_opname, target_idx, expected forward jump, expected return value 204 test_args = ( 205 ("JUMP_FORWARD", 0, False, 1), 206 ("JUMP_FORWARD", 1, True, 3), 207 ("JUMP_BACKWARD", 0, False, 1), 208 ("JUMP_BACKWARD", 1, True, 3), 209 ) 210 211 for test in test_args: 212 insts, code = create_test_code(test[0], test[1]) 213 # check if offset of latest jump instruction is forward/backward 214 for inst in reversed(insts): 215 if inst.opname.startswith("JUMP"): 216 if test[2]: 217 self.assertIn("FORWARD", inst.opname) 218 else: 219 self.assertIn("BACKWARD", inst.opname) 220 break 221 # run the code and check result 222 223 def dummy_fn(): 224 pass 225 226 dummy_fn.__code__ = code 227 self.assertEqual(dummy_fn(), test[3]) 228 229 dummy_opt = torch._dynamo.optimize("eager")(dummy_fn) 230 self.assertEqual(dummy_opt(), test[3]) 231 232 def test_exception_table_encode_varint(self): 233 # these numbers have no real meaning to them 234 nums = [ 235 0b111_101010_000000, 236 0b1100_111000_010101_101010, 237 ] 238 b = bytecode_transformation.encode_exception_table_varint( 239 nums[0] 240 ) + bytecode_transformation.encode_exception_table_varint(nums[1]) 241 nums_new = [] 242 b_iter = iter(bytes(b)) 243 while True: 244 try: 245 nums_new.append( 246 bytecode_transformation.decode_exception_table_varint(b_iter) 247 ) 248 except StopIteration: 249 break 250 self.assertEqual(nums, nums_new) 251 252 @skipIfNotPy311 253 def test_exception_table_parsing(self): 254 def fn(): 255 try: 256 with a(): 257 b() 258 c() 259 except Exception: 260 d() 261 finally: 262 e() 263 f() 264 265 tab = bytecode_transformation.parse_exception_table( 266 fn.__code__.co_exceptiontable 267 ) 268 b = bytecode_transformation.assemble_exception_table(tab) 269 self.assertEqual(b, fn.__code__.co_exceptiontable) 270 271 @skipIfNotPy311 272 def test_exception_table_e2e(self): 273 def fn(): 274 try: 275 with a(): 276 b() 277 c() 278 except Exception: 279 d() 280 finally: 281 e() 282 f() 283 284 def nothing(*args): 285 pass 286 287 code = bytecode_transformation.transform_code_object(fn.__code__, nothing) 288 self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) 289 290 @skipIfNotPy311 291 def test_exception_table_e2e_2(self): 292 # last instructions of an exn_table entry is a large instruction 293 # i.e., LOAD_GLOBAL a 294 def fn(): 295 try: 296 return a 297 except Exception: 298 pass 299 300 def nothing(*args): 301 pass 302 303 code = bytecode_transformation.transform_code_object(fn.__code__, nothing) 304 self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable) 305 306 @skipIfNotPy311 307 def test_exception_table_entry_propagation(self): 308 insts = [] 309 for _ in range(10): 310 insts.append(bytecode_transformation.create_instruction("NOP")) 311 insts[8].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 312 insts[0], insts[9], insts[0], 0, True 313 ) 314 insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 315 insts[0], insts[0], insts[1], 0, True 316 ) 317 insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 318 insts[0], insts[2], insts[2], 0, True 319 ) 320 insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 321 insts[4], insts[6], insts[3], 0, True 322 ) 323 insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 324 insts[9], insts[9], insts[4], 0, True 325 ) 326 insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 327 insts[7], insts[9], insts[5], 0, True 328 ) 329 bytecode_transformation.propagate_inst_exn_table_entries(insts) 330 expected = [1, 2, 2, 0, 3, 3, 3, 5, 5, 4] 331 for inst, exp in zip(insts, expected): 332 self.assertIsNotNone(inst.exn_tab_entry) 333 self.assertIs(inst.exn_tab_entry.target, insts[exp]) 334 335 @skipIfNotPy311 336 def test_compute_exception_table_nested(self): 337 insts = [] 338 for _ in range(20): 339 insts.append(bytecode_transformation.create_instruction("NOP")) 340 insts[10].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 341 insts[1], insts[10], insts[0], 0, True 342 ) 343 insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 344 insts[1], insts[1], insts[1], 0, True 345 ) 346 insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 347 insts[1], insts[3], insts[2], 0, True 348 ) 349 insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 350 insts[5], insts[7], insts[3], 0, True 351 ) 352 insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 353 insts[10], insts[10], insts[4], 0, True 354 ) 355 insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 356 insts[8], insts[10], insts[5], 0, True 357 ) 358 insts[14].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 359 insts[13], insts[17], insts[6], 0, True 360 ) 361 insts[16].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 362 insts[15], insts[16], insts[7], 0, True 363 ) 364 bytecode_transformation.update_offsets(insts) 365 tab = bytecode_transformation.compute_exception_table(insts) 366 expected = [ 367 (1, 1, 1), 368 (2, 3, 2), 369 (4, 4, 0), 370 (5, 7, 3), 371 (8, 9, 5), 372 (10, 10, 4), 373 (13, 14, 6), 374 (15, 16, 7), 375 (17, 17, 6), 376 ] 377 self.assertEqual(len(tab), len(expected)) 378 for entry, exp in zip(tab, expected): 379 self.assertEqual(entry.start, exp[0] * 2) 380 self.assertEqual(entry.end, exp[1] * 2) 381 self.assertEqual(entry.target, exp[2] * 2) 382 383 @skipIfNotPy311 384 def test_remove_dead_code_with_exn_table_entries(self): 385 create_instruction = bytecode_transformation.create_instruction 386 target1 = create_instruction("NOP") 387 target2 = create_instruction("NOP") 388 target3 = create_instruction("NOP") 389 exn_start = create_instruction("NOP") 390 exn_end = create_instruction("NOP") 391 insts = [ 392 create_instruction("JUMP_FORWARD", target=target1), 393 exn_start, # dead 394 target1, 395 create_instruction("JUMP_FORWARD", target=target3), 396 exn_end, # dead 397 target2, 398 target3, 399 ] 400 exn_start.exn_tab_entry = bytecode_transformation.InstructionExnTabEntry( 401 exn_start, exn_end, target2, 0, True 402 ) 403 bytecode_transformation.propagate_inst_exn_table_entries(insts) 404 insts = bytecode_analysis.remove_dead_code(insts) 405 self.assertEqual(len(insts), 5) 406 self.assertNotIn(exn_start, insts) 407 self.assertNotIn(exn_end, insts) 408 self.assertIn(target2, insts) 409 self.assertIn(target3, insts) 410 bytecode_transformation.update_offsets(insts) 411 tab = bytecode_transformation.compute_exception_table(insts) 412 self.assertEqual(len(tab), 1) 413 self.assertEqual(tab[0].start, 2) 414 self.assertEqual(tab[0].end, 4) 415 self.assertEqual(tab[0].target, 6) 416 417 def test_bytecode_from_template(self): 418 def fn(d1): 419 for k, v in d1.items(): 420 d2[k] = v 421 422 varname_map = {"d1": "var1", "d2": "var2", "k": "var3", "v": "var4"} 423 insts = bytecode_transformation.bytecode_from_template(fn, varname_map) 424 for inst in insts: 425 self.assertIsNone(inst.starts_line) 426 if inst.opname.startswith("LOAD"): 427 self.assertNotIn(inst.argval, varname_map) 428 if inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR"): 429 self.assertIsNone(inst.arg) 430 self.assertFalse(inst.opname.startswith("RETURN")) 431 432 @skipIfNotPy311 433 def test_bytecode_from_template_noprefix(self): 434 # Test that 3.11+ prefix instructions are removed 435 def gen_fn(): 436 cl = None 437 438 def fn(): 439 return cl 440 441 return fn 442 443 fn = gen_fn() 444 445 dis_insts = list(dis.get_instructions(fn)) 446 names = {inst.opname for inst in dis_insts} 447 self.assertIn("RESUME", names) 448 self.assertIn("COPY_FREE_VARS", names) 449 450 insts = bytecode_transformation.bytecode_from_template(fn) 451 names = {inst.opname for inst in insts} 452 self.assertNotIn("RESUME", names) 453 self.assertNotIn("COPY_FREE_VARS", names) 454 455 def test_bytecode_from_template_noreturn1(self): 456 # Test that functions with multiple returns will have their 457 # returns replaced with jumps to the end 458 def fn(): 459 if x: 460 return y 461 z = 3 462 return z 463 464 dis_insts = list(dis.get_instructions(fn)) 465 dis_returns = list(filter(lambda x: x.opname.startswith("RETURN"), dis_insts)) 466 self.assertGreater(len(dis_returns), 1) 467 self.assertTrue(dis_insts[-1].opname.startswith("RETURN")) 468 469 insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) 470 self.assertEqual(insts[-1].opname, "NOP") 471 self.assertEqual(len(dis_insts), len(insts)) 472 for i0, i1 in zip(dis_insts, insts): 473 if i0.opname.startswith("RETURN"): 474 if i1 is insts[-1]: 475 continue 476 self.assertIn("JUMP", i1.opname) 477 self.assertIs(i1.target, insts[-1]) 478 479 # Should work with 3.10, but testing with 3.11+ is sufficient. 480 # In 3.8, `fn` ends with a RETURN_VALUE. 481 @skipIfNotPy311 482 def test_bytecode_from_template_noreturn2(self): 483 # Test function that doesn't end with RETURN_VALUE 484 def fn(): 485 if x: 486 return x 487 if x: 488 return x 489 raise RuntimeError 490 491 dis_insts = list(dis.get_instructions(fn)) 492 self.assertFalse(dis_insts[-1].opname.startswith("RETURN")) 493 494 insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) 495 self.assertEqual(insts[-1].opname, "NOP") 496 self.assertEqual(insts[-2].opname, dis_insts[-1].opname) 497 self.assertEqual(len(dis_insts) + 1, len(insts)) 498 for i0, i1 in zip(dis_insts, insts): 499 if i0.opname.startswith("RETURN"): 500 self.assertIn("JUMP", i1.opname) 501 self.assertIs(i1.target, insts[-1]) 502 503 @skipIfNotPy312 504 def test_bytecode_from_template_noreturn_const(self): 505 # Test 3.12+ RETURN_CONST 506 def fn(): 507 if x: 508 return 1 509 return 0 510 511 dis_insts = list(dis.get_instructions(fn)) 512 dis_return_consts = list( 513 filter(lambda x: x.opname == "RETURN_CONST", dis_insts) 514 ) 515 self.assertGreater(len(dis_return_consts), 1) 516 self.assertTrue(dis_insts[-1].opname == "RETURN_CONST") 517 518 insts = bytecode_transformation.bytecode_from_template(fn, noprefix=False) 519 self.assertEqual(insts[-1].opname, "NOP") 520 insts_i = 0 521 for i, inst in enumerate(dis_insts): 522 if inst.opname == "RETURN_CONST": 523 self.assertEqual(insts[insts_i].opname, "LOAD_CONST") 524 insts_i += 1 525 if insts_i != len(insts) - 1: 526 self.assertIn("JUMP", insts[insts_i].opname) 527 self.assertIs(insts[insts_i].target, insts[-1]) 528 insts_i += 1 529 530 531class BytecodeHookTests(torch._dynamo.test_case.TestCase): 532 def test_bytecode_hook(self): 533 def fn(a, b): 534 return a - b * 10 535 536 def hook(code, out_code): 537 print(code) 538 print(out_code) 539 return code 540 541 torch._dynamo.reset() 542 handle = torch._dynamo.convert_frame.register_bytecode_hook(hook) 543 try: 544 opt_fn = torch.compile(fn) 545 for i in range(2, 12): 546 opt_fn(torch.randn(i), torch.randn(i)) 547 finally: 548 handle.remove() 549 550 551if __name__ == "__main__": 552 from torch._dynamo.test_case import run_tests 553 554 run_tests() 555