# mypy: allow-untyped-defs
import copy
import dataclasses
import dis
import itertools
import sys
import types
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union

from .bytecode_analysis import (
    get_indexof,
    propagate_line_nums,
    remove_extra_line_nums,
    stacksize_analysis,
)


@dataclasses.dataclass
class InstructionExnTabEntry:
    """
    An exception table entry whose `start`, `end` and `target` point at
    Instruction objects instead of byte offsets, which makes editing the
    instruction list easier. Byte offsets are recomputed from these pointers
    by `compute_exception_table`.
    """

    start: "Instruction"
    end: "Instruction"
    target: "Instruction"
    depth: int
    lasti: bool

    def __repr__(self) -> str:
        return (
            f"InstructionExnTabEntry(start={self.start.short_inst_repr()}, "
            f"end={self.end.short_inst_repr()}, "
            f"target={self.target.short_inst_repr()}, "
            f"depth={self.depth}, lasti={self.lasti})"
        )

    def __eq__(self, o) -> bool:
        # Instructions are compared by identity: two entries are equal only if
        # they reference the very same start/end/target Instruction objects.
        # Comparing against a non-entry defers to Python (-> False for ==)
        # instead of raising AttributeError.
        if not isinstance(o, InstructionExnTabEntry):
            return NotImplemented
        return (
            self.start is o.start
            and self.end is o.end
            and self.target is o.target
            and self.depth == o.depth
            and self.lasti == o.lasti
        )


@dataclasses.dataclass
class Instruction:
    """A mutable version of dis.Instruction"""

    opcode: int
    opname: str
    arg: Optional[int]
    argval: Any
    offset: Optional[int] = None
    starts_line: Optional[int] = None
    is_jump_target: bool = False
    positions: Optional["dis.Positions"] = None
    # extra fields to make modification easier:
    # jump destination as an Instruction (see virtualize_jumps/devirtualize_jumps)
    target: Optional["Instruction"] = None
    # exception table entry covering this instruction (only used on 3.11+)
    exn_tab_entry: Optional[InstructionExnTabEntry] = None

    # Identity semantics: instructions are mutable but are used as dict keys
    # (e.g. `indexof[inst.target]`), so hash/eq must not depend on fields.
    def __hash__(self) -> int:
        return id(self)

    def __eq__(self, other) -> bool:
        return self is other

    def short_inst_repr(self) -> str:
        return f"Instruction(opname={self.opname}, offset={self.offset})"


def convert_instruction(i: dis.Instruction) -> Instruction:
    """Convert an immutable dis.Instruction into a mutable Instruction."""
    if sys.version_info >= (3, 13):
        # on 3.13+ the line number is read from `line_number`
        starts_line = i.line_number
    else:
        starts_line = i.starts_line
    return Instruction(
        i.opcode,
        i.opname,
        i.arg,
        i.argval,
        i.offset,
        starts_line,
        i.is_jump_target,
        getattr(i, "positions", None),
    )


class _NotProvided:
    """Sentinel meaning "argval was not passed", so `argval=None` stays usable."""

    def __repr__(self) -> str:
        return "_NotProvided"


def inst_has_op_bits(name):
    """
    Return True if the low bits of instruction `name`'s arg carry flags
    (LOAD_GLOBAL on 3.11+, LOAD_ATTR / LOAD_SUPER_ATTR on 3.12+).
    """
    return (sys.version_info >= (3, 11) and name == "LOAD_GLOBAL") or (
        sys.version_info >= (3, 12) and name in ("LOAD_ATTR", "LOAD_SUPER_ATTR")
    )


def create_instruction(
    name, *, arg=None, argval=_NotProvided, target=None
) -> Instruction:
    """
    At most one of `arg`, `argval`, and `target` can be not None/_NotProvided.
    This is to prevent ambiguity, e.g. does
        create_instruction("LOAD_CONST", 5)
    mean load the constant at co_consts[5], or load the constant 5?

    If `arg` is not provided, it will be computed during assembly from
    `argval` or `target`.

    Bits in the args of instructions LOAD_GLOBAL, LOAD_ATTR (3.12+), and LOAD_SUPER_ATTR
    modify the behavior of the instruction. In this case, we allow both `arg`
    and `argval` to be set. The value of `arg` here is expected to be the value of
    the op bits and the true value of `arg` will be computed during assembly.
    If `arg` is not set, the bits are assumed to be 0.

    Raises:
        RuntimeError: on an ambiguous combination of `arg`/`argval`/`target`,
            or a non-int `arg`.
    """

    # allow for instructions with op bits to have both arg and argval specified
    if inst_has_op_bits(name):
        if target is not None:
            raise RuntimeError("target cannot be specified for instruction")
        if arg is None:
            arg = 0
    else:
        cnt = (arg is not None) + (argval is not _NotProvided) + (target is not None)
        if cnt > 1:
            raise RuntimeError(
                "only one of arg, argval, and target can be not None/_NotProvided"
            )
    if arg is not None and not isinstance(arg, int):
        raise RuntimeError("instruction arg must be int or None")
    return Instruction(
        opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target
    )


# Python 3.11 has no JUMP_ABSOLUTE; unconditional jumps are relative there.
def create_jump_absolute(target) -> Instruction:
    """
    Create an unconditional jump to `target`. On 3.11+ this is JUMP_FORWARD;
    `devirtualize_jumps` flips it to a backward jump if the target ends up
    before the jump.
    """
    inst = "JUMP_FORWARD" if sys.version_info >= (3, 11) else "JUMP_ABSOLUTE"
    return create_instruction(inst, target=target)


def create_dup_top() -> Instruction:
    """Duplicate TOS: `COPY 1` on 3.11+, `DUP_TOP` before."""
    if sys.version_info >= (3, 11):
        return create_instruction("COPY", arg=1)
    return create_instruction("DUP_TOP")


def create_rot_n(n) -> List[Instruction]:
    """
    Returns a "simple" sequence of instructions that rotates TOS to the n-th
    position in the stack. For Python < 3.11, returns a single ROT_*
    instruction. If no such instruction exists, an error is raised and the
    caller is expected to generate an equivalent sequence of instructions.
    For Python >= 3.11, any rotation can be expressed as a simple sequence of
    swaps.

    Raises:
        AttributeError: if the rotation is not available on this Python version.
    """
    if n <= 1:
        # don't rotate
        return []

    if sys.version_info >= (3, 11):
        # rotate can be expressed as a sequence of swap operations
        # e.g. rotate 3 is equivalent to swap 3, swap 2
        return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)]

    # ensure desired rotate function exists
    if sys.version_info < (3, 8) and n >= 4:
        raise AttributeError(f"rotate {n} not supported for Python < 3.8")
    if sys.version_info < (3, 10) and n >= 5:
        raise AttributeError(f"rotate {n} not supported for Python < 3.10")

    if n <= 4:
        return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])]
    return [create_instruction("ROT_N", arg=n)]


def add_push_null(
    inst_or_insts: Union[Instruction, List[Instruction]],
) -> List[Instruction]:
    """
    Appends or prepends a PUSH_NULL instruction to `inst_or_insts`,
    depending on Python version. Used when you know that
    `inst_or_insts` generates a callable that will be called.

    NOTE: Assumes `inst_or_insts` is a single instruction or sequence of
    instructions that pushes exactly 1 object to the stack that is to
    be called. It is important that you include ALL instructions that
    construct the callable - not just the first instruction/a prefix.

    Will attempt to use the NULL push bit for instructions
    with such bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+, LOAD_SUPER_ATTR).
    In this case, instructions WILL be modified.
    """
    if isinstance(inst_or_insts, Instruction):
        insts = [inst_or_insts]
    else:
        insts = inst_or_insts

    def inst_has_bit_set(idx):
        assert insts[idx].arg is not None
        return insts[idx].arg & 1 == 1

    def set_inst_bit(idx):
        assert insts[idx].arg is not None
        insts[idx].arg |= 1

    if sys.version_info >= (3, 13):
        # In 3.13, NULL follows the callable
        if inst_has_op_bits(insts[-1].opname) and not inst_has_bit_set(-1):
            # All insts with op bits have the push_null bit as the last one.
            # Only set the bit if it hasn't been set - otherwise, we need
            # to add another PUSH_NULL.
            set_inst_bit(-1)
        else:
            insts = insts + [create_instruction("PUSH_NULL")]
    elif sys.version_info >= (3, 12):
        # LOAD_ATTR/LOAD_SUPER_ATTR at the end
        # We assume that `insts` will only load 1 object, so
        # LOAD_GLOBAL at the end doesn't need to be checked
        if inst_has_op_bits(insts[-1].opname) and not inst_has_bit_set(-1):
            set_inst_bit(-1)
        elif insts[0].opname == "LOAD_GLOBAL" and not inst_has_bit_set(0):
            set_inst_bit(0)
        else:
            insts = [create_instruction("PUSH_NULL")] + insts
    elif sys.version_info >= (3, 11):
        # 3.11 introduced NULL preceding callable
        if inst_has_op_bits(insts[0].opname) and not inst_has_bit_set(0):
            set_inst_bit(0)
        else:
            insts = [create_instruction("PUSH_NULL")] + insts
    return insts


def add_push_null_call_function_ex(
    inst_or_insts: Union[Instruction, List[Instruction]],
) -> List[Instruction]:
    """Like add_push_null, but the low bit of LOAD_ATTR/LOAD_SUPER_ATTR
    is not set, due to an expected CALL_FUNCTION_EX instruction.
    """
    if isinstance(inst_or_insts, Instruction):
        insts = [inst_or_insts]
    else:
        insts = inst_or_insts

    if sys.version_info < (3, 11):
        return insts

    # NULL goes after the callable on 3.13+, before it on 3.11/3.12
    idx = -1 if sys.version_info >= (3, 13) else 0
    if insts[idx].opname == "LOAD_GLOBAL":
        assert insts[idx].arg is not None
        if insts[idx].arg & 1 == 0:  # type: ignore[operator]
            insts[idx].arg |= 1  # type: ignore[operator]
            return insts

    if sys.version_info >= (3, 13):
        insts = insts + [create_instruction("PUSH_NULL")]
    else:
        insts = [create_instruction("PUSH_NULL")] + insts

    return insts


def create_call_function(nargs, push_null) -> List[Instruction]:
    """
    Creates a sequence of instructions that makes a function call.

    `push_null` is used in Python 3.11+ only. It is used in codegen when
    a function call is intended to be made with the NULL + fn convention,
    and we know that the NULL has not been pushed yet. We will push a
    NULL and rotate it to the correct position immediately before making
    the function call.

    `push_null` should be True if no NULL is pushed for the callable.
    Conversely, `push_null` should be False if a NULL was pushed for the callable.
    Prefer using `push_null=False` when possible since we will not need to rotate
    NULL to the right place, which is less efficient.

    Generally, you should codegen a function by using `add_push_null` then
    `create_call_function` with `push_null=False`.

    Example of when to set push_null False:

    insts = [
        create_instruction("LOAD_GLOBAL", argval="torch"),
        create_instruction("LOAD_ATTR", argval="nn"),
        create_instruction("LOAD_ATTR", argval="functional"),
        create_instruction("LOAD_ATTR", argval="relu"),
    ]
    insts = add_push_null(insts)
    insts.append(create_instruction("LOAD_FAST", argval="x"))
    insts.extend(create_call_function(1, False))

    Example of when to set push_null True:

    insts = [create_instruction("LOAD_FAST", argval="x")]
    for should_wrap, wrapper_name in wrappers:
        if should_wrap:
            insts.extend([
                create_instruction("LOAD_GLOBAL", argval=wrapper_name),
                create_instruction("SWAP", arg=2),
                *create_call_function(1, True),
            ])
    """
    if sys.version_info >= (3, 11):
        output = []
        if push_null:
            output.append(create_instruction("PUSH_NULL"))
            # 3.13 swapped NULL and callable
            rots = nargs + 1 if sys.version_info >= (3, 13) else nargs + 2
            output.extend(create_rot_n(rots))
        if sys.version_info < (3, 12):
            output.append(create_instruction("PRECALL", arg=nargs))
        output.append(create_instruction("CALL", arg=nargs))
        return output
    return [create_instruction("CALL_FUNCTION", arg=nargs)]


def create_call_method(nargs) -> List[Instruction]:
    """Instructions that call a method loaded by `create_load_method`."""
    if sys.version_info >= (3, 12):
        return [create_instruction("CALL", arg=nargs)]
    if sys.version_info >= (3, 11):
        return [
            create_instruction("PRECALL", arg=nargs),
            create_instruction("CALL", arg=nargs),
        ]
    return [create_instruction("CALL_METHOD", arg=nargs)]


def create_load_method(name) -> Instruction:
    """Load method `name` from TOS (LOAD_ATTR with low bit set on 3.12+)."""
    if sys.version_info >= (3, 12):
        # in 3.12, create a LOAD_ATTR instruction with the low bit set
        return create_instruction("LOAD_ATTR", arg=1, argval=name)
    return create_instruction("LOAD_METHOD", argval=name)


def create_setup_with(target) -> Instruction:
    """Enter a `with` block: BEFORE_WITH on 3.11+, SETUP_WITH before."""
    opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH"
    return create_instruction(opname, target=target)


def create_swap(n) -> List[Instruction]:
    """
    Instructions that swap TOS with the n-th stack item (n=1 is a no-op).
    On 3.11+ this is a single SWAP; earlier versions use the list-based
    sequence below.
    """
    if sys.version_info >= (3, 11):
        return [create_instruction("SWAP", arg=n)]
    # in Python < 3.11, SWAP is a macro that expands to multiple instructions
    if n == 1:
        return []
    # e.g. swap "a" and "b" in this stack:
    # 0 a 1 2 3 b
    # 0 a [1 2 3 b]
    # 0 a [1 2 3 b] [1 2 3 b]
    # 0 a [1 2 3 b] [1 2 3 b] -1
    # 0 a [1 2 3 b] b
    # 0 b a [1 2 3 b]
    # 0 b a [1 2 3 b] [1 2 3 b]
    # 0 b [1 2 3 b] a [1 2 3 b]
    # 0 b [1 2 3 b] a [1 2 3 b] -1
    # 0 b [1 2 3 a]
    # 0 b [1 2 3 a] [1 2 3 a]
    # 0 b [1 2 3 a] [1 2 3 a] reverse
    # 0 b [a 3 2 1] None
    # 0 b [a 3 2 1]
    # 0 b 1 2 3 a
    return [
        create_instruction("BUILD_LIST", arg=n - 1),
        create_instruction("DUP_TOP"),
        create_instruction("LOAD_CONST", argval=-1),
        create_instruction("BINARY_SUBSCR"),
        create_instruction("ROT_THREE"),
        create_instruction("DUP_TOP"),
        create_instruction("ROT_THREE"),
        create_instruction("LOAD_CONST", argval=-1),
        create_instruction("STORE_SUBSCR"),
        create_instruction("DUP_TOP"),
        create_load_method("reverse"),
        *create_call_method(0),
        create_instruction("POP_TOP"),
        create_instruction("UNPACK_SEQUENCE", arg=n - 1),
    ]


def lnotab_writer(
    lineno: int, byteno: int = 0
) -> Tuple[List[int], Callable[[int, int], None]]:
    """
    Used to create types.CodeType.co_lnotab
    See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
    This is the internal format of the line number table if Python < 3.10

    Returns (lnotab, update): `update(lineno_new, byteno_new)` appends the
    (byte delta, line delta) pairs needed to reach the new position.
    """
    assert sys.version_info < (3, 10)
    lnotab: List[int] = []

    def update(lineno_new, byteno_new):
        nonlocal byteno, lineno
        while byteno_new != byteno or lineno_new != lineno:
            # deltas are clamped to one byte each; large jumps take several pairs
            byte_offset = max(0, min(byteno_new - byteno, 255))
            line_offset = max(-128, min(lineno_new - lineno, 127))
            assert byte_offset != 0 or line_offset != 0
            byteno += byte_offset
            lineno += line_offset
            lnotab.extend((byte_offset, line_offset & 0xFF))

    return lnotab, update


def linetable_310_writer(first_lineno):
    """
    Used to create types.CodeType.co_linetable
    See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
    This is the internal format of the line number table for Python 3.10

    Returns (linetable, update, end). `update(lineno_new, byteno_new)` closes
    the byte range of the previous line; `end(total_bytes)` closes the last one.
    """
    assert sys.version_info >= (3, 10) and sys.version_info < (3, 11)
    linetable: List[int] = []
    lineno = first_lineno
    lineno_delta = 0
    byteno = 0

    def _update(byteno_delta, lineno_delta):
        while byteno_delta != 0 or lineno_delta != 0:
            byte_offset = max(0, min(byteno_delta, 254))
            line_offset = max(-127, min(lineno_delta, 127))
            assert byte_offset != 0 or line_offset != 0
            byteno_delta -= byte_offset
            lineno_delta -= line_offset
            linetable.extend((byte_offset, line_offset & 0xFF))

    def update(lineno_new, byteno_new):
        nonlocal lineno, lineno_delta, byteno
        byteno_delta = byteno_new - byteno
        byteno = byteno_new
        # emit the pending line delta for the bytes since the last update
        _update(byteno_delta, lineno_delta)
        lineno_delta = lineno_new - lineno
        lineno = lineno_new

    def end(total_bytes):
        _update(total_bytes - byteno, lineno_delta)

    return linetable, update, end


def encode_varint(n: int) -> List[int]:
    """
    6-bit chunk encoding of an unsigned integer
    See https://github.com/python/cpython/blob/3.11/Objects/locations.md
    """
    assert n >= 0
    b = [n & 63]
    n >>= 6
    while n > 0:
        b[-1] |= 64
        b.append(n & 63)
        n >>= 6
    return b


def linetable_311_writer(first_lineno: int):
    """
    Used to create types.CodeType.co_linetable
    See https://github.com/python/cpython/blob/3.11/Objects/locations.md
    This is the internal format of the line number table for Python 3.11

    Returns (linetable, update): `update(positions, inst_size)` appends the
    entries for one instruction of `inst_size` code units.
    """
    assert sys.version_info >= (3, 11)
    linetable = []
    lineno = first_lineno

    def update(positions: "dis.Positions", inst_size):
        nonlocal lineno
        lineno_new = positions.lineno if positions else None

        def _update(delta, size):
            assert 0 < size <= 8
            # first byte - use 13 (no column info) if positions is
            # malformed, otherwise use 14 (long form)
            other_varints: Tuple[int, ...] = ()
            if (
                positions
                and positions.lineno is not None
                and positions.end_lineno is not None
                and positions.col_offset is not None
                and positions.end_col_offset is not None
            ):
                linetable.append(0b1_1110_000 + size - 1)
                # for whatever reason, column offset needs `+ 1`
                # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603
                other_varints = (
                    positions.end_lineno - positions.lineno,
                    positions.col_offset + 1,
                    positions.end_col_offset + 1,
                )
            else:
                linetable.append(0b1_1101_000 + size - 1)
            # encode signed int
            if delta < 0:
                delta = ((-delta) << 1) | 1
            else:
                delta <<= 1
            # encode unsigned int
            linetable.extend(encode_varint(delta))
            for n in other_varints:
                linetable.extend(encode_varint(n))

        if lineno_new is None:
            lineno_delta = 0
        else:
            lineno_delta = lineno_new - lineno
            lineno = lineno_new
        # An entry covers at most 8 code units, so long instructions are split.
        # Line deltas accumulate across entries, so only the first chunk may
        # carry the delta; continuation chunks stay on the same line.
        while inst_size > 8:
            _update(lineno_delta, 8)
            lineno_delta = 0
            inst_size -= 8
        _update(lineno_delta, inst_size)

    return linetable, update


@dataclasses.dataclass
class ExceptionTableEntry:
    """Offset-based exception table entry, as stored in co_exceptiontable."""

    # byte offset of the first covered instruction
    start: int
    # inclusive byte offset of the last 2 bytes of the last covered instruction
    end: int
    # byte offset of the handler
    target: int
    # depth/lasti are stored as-is; see CPython's exception_handling_notes.txt
    depth: int
    lasti: bool


def encode_exception_table_varint(n: int) -> List[int]:
    """
    Similar to `encode_varint`, but the 6-bit chunks are ordered in reverse.
    """
    assert n >= 0
    b = [n & 63]
    n >>= 6
    while n > 0:
        b.append(n & 63)
        n >>= 6
    b.reverse()
    # continuation bit on every chunk but the last
    for i in range(len(b) - 1):
        b[i] |= 64
    return b


def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int:
    """
    Inverse of `encode_exception_table_varint`.

    Raises StopIteration if `bytes_iter` is exhausted.
    """
    b = next(bytes_iter)
    val = b & 63
    while b & 64:
        val <<= 6
        b = next(bytes_iter)
        val |= b & 63
    return val


def check_exception_table(tab: List[ExceptionTableEntry]) -> None:
    """
    Verifies that a list of ExceptionTableEntries will make a well-formed
    jump table: entries are non-empty, sorted, and do not overlap.
    """
    # Check each entry on its own so a single-entry table is validated too.
    for entry in tab:
        assert entry.start <= entry.end
    for prev, nxt in zip(tab, tab[1:]):
        assert prev.end < nxt.start


def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]:
    """
    Parse the exception table according to
    https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
    """
    exntab_iter = iter(exntab)
    tab = []
    try:
        while True:
            # offsets/lengths are stored in 2-byte code units
            start = decode_exception_table_varint(exntab_iter) * 2
            length = decode_exception_table_varint(exntab_iter) * 2
            end = start + length - 2
            target = decode_exception_table_varint(exntab_iter) * 2
            dl = decode_exception_table_varint(exntab_iter)
            depth = dl >> 1
            lasti = bool(dl & 1)
            tab.append(ExceptionTableEntry(start, end, target, depth, lasti))
    except StopIteration:
        check_exception_table(tab)
        return tab


def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
    """
    Inverse of parse_exception_table - encodes list of exception
    table entries into bytes.
    """
    b = []
    for entry in tab:
        first_entry = encode_exception_table_varint(entry.start // 2)
        # bit 7 marks the first byte of an entry
        first_entry[0] |= 1 << 7
        b.extend(first_entry)
        length = entry.end - entry.start + 2
        b.extend(encode_exception_table_varint(length // 2))
        b.extend(encode_exception_table_varint(entry.target // 2))
        dl = (entry.depth << 1) + entry.lasti
        b.extend(encode_exception_table_varint(dl))
    return bytes(b)


def assemble(instructions: List[Instruction], firstlineno: int) -> Tuple[bytes, bytes]:
    """Do the opposite of dis.get_instructions()

    Returns (code bytes, line table bytes). Only the low byte of each `arg`
    is written; larger args are expected to be carried by EXTENDED_ARG
    instructions already present in `instructions`.
    """
    code: List[int] = []
    if sys.version_info >= (3, 11):
        lnotab, update_lineno = linetable_311_writer(firstlineno)
        for i, inst in enumerate(instructions):
            if inst.opname == "EXTENDED_ARG":
                # EXTENDED_ARG gets its own 1-code-unit entry, located at the
                # position of the instruction it extends.
                inst_size = 1
                for j in (1, 2, 3):
                    if instructions[i + j].opname != "EXTENDED_ARG":
                        inst.positions = instructions[i + j].positions
                        break
            else:
                inst_size = instruction_size(inst) // 2
            update_lineno(inst.positions, inst_size)
            arg = inst.arg or 0
            code.extend((inst.opcode, arg & 0xFF))
            # zero-filled inline cache entries
            for _ in range(instruction_size(inst) // 2 - 1):
                code.extend((0, 0))
    else:
        if sys.version_info < (3, 10):
            lnotab, update_lineno = lnotab_writer(firstlineno)
        else:
            lnotab, update_lineno, end = linetable_310_writer(firstlineno)

        for inst in instructions:
            if inst.starts_line is not None:
                update_lineno(inst.starts_line, len(code))
            arg = inst.arg or 0
            code.extend((inst.opcode, arg & 0xFF))

        if sys.version_info >= (3, 10):
            end(len(code))

    return bytes(code), bytes(lnotab)


def _get_instruction_by_offset(offset_to_inst: Dict[int, Instruction], offset: int):
    """
    Get the instruction located at a given offset, accounting for EXTENDED_ARGs
    """
    for n in (0, 2, 4, 6):
        if offset_to_inst[offset + n].opcode != dis.EXTENDED_ARG:
            return offset_to_inst[offset + n]
    return None


def virtualize_jumps(instructions) -> None:
    """Replace jump targets with pointers to make editing easier"""
    jump_targets = {inst.offset: inst for inst in instructions}

    for inst in instructions:
        if inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel:
            inst.target = _get_instruction_by_offset(jump_targets, inst.argval)


_REL_JUMPS = set(dis.hasjrel)


def flip_jump_direction(instruction: Instruction) -> None:
    """
    Swap FORWARD <-> BACKWARD in a 3.11+ relative jump's opname, in place.

    Raises:
        RuntimeError: on Python < 3.11.
        AttributeError: if the instruction is neither a forward nor backward jump.
    """
    if sys.version_info < (3, 11):
        raise RuntimeError("Cannot flip jump direction in Python < 3.11")
    if "FORWARD" in instruction.opname:
        instruction.opname = instruction.opname.replace("FORWARD", "BACKWARD")
    elif "BACKWARD" in instruction.opname:
        instruction.opname = instruction.opname.replace("BACKWARD", "FORWARD")
    else:
        raise AttributeError("Instruction is not a forward or backward jump")
    instruction.opcode = dis.opmap[instruction.opname]
    assert instruction.opcode in _REL_JUMPS


def _get_instruction_front(instructions: List[Instruction], idx: int):
    """
    i.e. get the first EXTENDED_ARG instruction (if any) when targeting
    instructions[idx] with a jump.
    """
    target = instructions[idx]
    for offset in (1, 2, 3):
        if idx >= offset and instructions[idx - offset].opcode == dis.EXTENDED_ARG:
            target = instructions[idx - offset]
        else:
            break
    return target


def devirtualize_jumps(instructions):
    """Fill in args for virtualized jump target after instructions may have moved"""
    jumps = set(dis.hasjabs).union(set(dis.hasjrel))

    # check for negative jump args and fix them
    for inst in instructions:
        if inst.opcode in jumps:
            if inst.opcode not in dis.hasjabs:
                if inst.target.offset < inst.offset:
                    if sys.version_info < (3, 11):
                        raise RuntimeError("Got negative jump offset for Python < 3.11")
                    # forward jumps become backward
                    if "FORWARD" in inst.opname:
                        flip_jump_direction(inst)
                else:
                    # backward jumps become forward
                    if sys.version_info >= (3, 11) and "BACKWARD" in inst.opname:
                        flip_jump_direction(inst)

    # jump instruction size may have changed due to flips
    update_offsets(instructions)
    indexof = get_indexof(instructions)

    # compute jump instruction arg
    for inst in instructions:
        if inst.opcode in jumps:
            target = _get_instruction_front(instructions, indexof[inst.target])
            if inst.opcode in dis.hasjabs:
                if sys.version_info < (3, 10):
                    inst.arg = target.offset
                elif sys.version_info < (3, 11):
                    # `arg` is expected to be bytecode offset, whereas `offset` is byte offset.
                    # Divide since bytecode is 2 bytes large.
                    inst.arg = target.offset // 2
                else:
                    raise RuntimeError("Python 3.11+ should not have absolute jumps")
            else:  # relative jump
                # byte offset between target and next instruction
                inst.arg = abs(
                    int(target.offset - inst.offset - instruction_size(inst))
                )
                if sys.version_info >= (3, 10):
                    # see bytecode size comment in the absolute jump case above
                    inst.arg //= 2
            inst.argval = target.offset
            inst.argrepr = f"to {target.offset}"


def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruction]):
    """Replace exception table entries with pointers to make editing easier"""
    exn_tab = parse_exception_table(exn_tab_bytes)
    offset_to_inst = {cast(int, inst.offset): inst for inst in instructions}
    offsets = sorted(offset_to_inst.keys())
    end_offset_idx = 0
    exn_tab_iter = iter(exn_tab)

    def step():
        """Advance to the next entry; raises StopIteration when exhausted."""
        nonlocal end_offset_idx
        entry = next(exn_tab_iter)
        # find rightmost offset <= entry.end, since entry.end may not be
        # an actual instruction, e.g. if the end instruction is LOAD_GLOBAL,
        # which takes more than 2 bytes, then entry.end points to the end
        # of the LOAD_GLOBAL instruction, not the beginning.
        while end_offset_idx < len(offsets) and offsets[end_offset_idx] <= entry.end:
            end_offset_idx += 1
        assert end_offset_idx > 0
        end_offset = offsets[end_offset_idx - 1]
        inst_entry = InstructionExnTabEntry(
            _get_instruction_by_offset(offset_to_inst, entry.start),
            _get_instruction_by_offset(offset_to_inst, end_offset),
            _get_instruction_by_offset(offset_to_inst, entry.target),
            entry.depth,
            entry.lasti,
        )
        return entry, inst_entry

    try:
        entry, inst_entry = step()
        for inst in instructions:
            while inst.offset > entry.end:
                entry, inst_entry = step()
            if inst.offset >= entry.start:
                # each instruction gets its own copy of the entry
                inst.exn_tab_entry = copy.copy(inst_entry)
    except StopIteration:
        # no (more) entries: remaining instructions are not covered
        pass


def compute_exception_table(
    instructions: List[Instruction],
) -> List[ExceptionTableEntry]:
    """Compute exception table in list format from instructions with exn_tab_entries"""
    exn_dict: Dict[Tuple[int, int], Tuple[int, int, bool]] = {}
    indexof = get_indexof(instructions)

    for inst in instructions:
        if inst.exn_tab_entry:
            # account for prefixed EXTENDED_ARGS
            start = _get_instruction_front(
                instructions, indexof[inst.exn_tab_entry.start]
            ).offset
            # point to the last 2 bytes of the end instruction
            end = (
                cast(int, inst.exn_tab_entry.end.offset)
                + instruction_size(inst.exn_tab_entry.end)
                - 2
            )
            target = _get_instruction_front(
                instructions, indexof[inst.exn_tab_entry.target]
            ).offset
            key = (start, end)
            val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti)
            if key in exn_dict:
                assert exn_dict[key] == val
            exn_dict[key] = val

    # Dynamo may construct nested exception table entries for convenience,
    # but Python expects exception table entries to not overlap.
    # NOTE: below, "keys" refer to old instruction entries' starts and ends,
    # and "entries" refer to the generated exception table entries.

    # Sort keys by increasing start, then decreasing end
    keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1]))
    # smallest byte that the next exception table entry can start at
    nexti = 0
    # stack of current nested keys
    key_stack: List[Tuple[int, int]] = []
    exn_tab: List[ExceptionTableEntry] = []

    def pop():
        """
        Pop the key_stack and append an exception table entry if possible.
        """
        nonlocal nexti
        if key_stack:
            key = key_stack.pop()
            if nexti <= key[1]:
                exn_tab.append(
                    ExceptionTableEntry(max(key[0], nexti), key[1], *exn_dict[key])
                )
                nexti = key[1] + 2

    for key in keys_sorted:
        # pop keys that are no longer nested over the current key
        while key_stack and key_stack[-1][1] < key[0]:
            pop()
        if key_stack:
            # create an entry covering to the current key, if possible
            assert key_stack[-1][0] <= key[0] <= key[1] <= key_stack[-1][1]
            left = max(nexti, key_stack[-1][0])
            if left < key[0]:
                exn_tab.append(
                    ExceptionTableEntry(left, key[0] - 2, *exn_dict[key_stack[-1]])
                )
            nexti = key[0]
        key_stack.append(key)
    while key_stack:
        pop()
    check_exception_table(exn_tab)
    return exn_tab


def check_inst_exn_tab_entries_nested(
    tab: List[InstructionExnTabEntry], indexof
) -> None:
    """
    Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's,
    i.e. no entries partially overlap.
    "Properly sorted" means entries are sorted by increasing starts, then
    decreasing ends.
    """
    entry_stack: List[Tuple[int, int]] = []
    for entry in tab:
        key = (indexof[entry.start], indexof[entry.end])
        while entry_stack and entry_stack[-1][1] < key[0]:
            entry_stack.pop()
        if entry_stack:
            assert entry_stack[-1][0] <= key[0] <= key[1] <= entry_stack[-1][1]
        entry_stack.append(key)


def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
    """
    Copies exception table entries to all instructions in an entry's range.
    Supports nested exception table entries.
    """
    indexof = get_indexof(instructions)
    entries: Dict[Tuple[int, int], InstructionExnTabEntry] = {}
    for inst in instructions:
        if inst.exn_tab_entry:
            key = (
                indexof[inst.exn_tab_entry.start],
                indexof[inst.exn_tab_entry.end],
            )
            if key in entries:
                assert inst.exn_tab_entry == entries[key]
            entries[key] = inst.exn_tab_entry
    sorted_entries = [
        entries[key] for key in sorted(entries.keys(), key=lambda t: (t[0], -t[1]))
    ]
    check_inst_exn_tab_entries_nested(sorted_entries, indexof)
    # Propagation of nested entries works since nested entries come later
    # in sorted order.
    for entry in sorted_entries:
        for i in range(indexof[entry.start], indexof[entry.end] + 1):
            instructions[i].exn_tab_entry = copy.copy(entry)


def check_inst_exn_tab_entries_valid(instructions: List[Instruction]):
    """
    Checks that exn_tab_entries of instructions are valid.
    An entry's start, end, and target must be in instructions.
    Instructions with an exn_tab_entry are located within
    the entry's start and end instructions.
    Instructions do not share exn_tab_entries.

    Implicitly checks for no duplicate instructions.
    """
    indexof = get_indexof(instructions)
    exn_tab_entry_set = set()
    for i, inst in enumerate(instructions):
        if inst.exn_tab_entry:
            assert sys.version_info >= (3, 11)
            assert id(inst.exn_tab_entry) not in exn_tab_entry_set
            exn_tab_entry_set.add(id(inst.exn_tab_entry))
            entry = inst.exn_tab_entry
            assert entry.start in indexof
            assert entry.end in indexof
            assert entry.target in indexof
            assert indexof[entry.start] <= i <= indexof[entry.end]


def strip_extended_args(instructions: List[Instruction]) -> None:
    """Remove all EXTENDED_ARG instructions from `instructions`, in place."""
    instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]


def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
    """LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
    assert sys.version_info < (3, 11)
    rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
    for inst in instructions:
        if inst.opname in rewrites:
            inst.opname = rewrites[inst.opname]
            inst.opcode = dis.opmap[inst.opname]
    return instructions


def remove_jump_if_none(instructions: List[Instruction]) -> None:
    """
    Rewrite each `*_NONE` conditional jump into
    LOAD_CONST None; IS_OP (1 for the NOT variants); POP_JUMP_*_IF_TRUE,
    in place, keeping the jump target and exception table entries.
    """
    new_insts = []
    for inst in instructions:
        new_insts.append(inst)
        if "_NONE" in inst.opname:
            is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
            is_op.argval = is_op.arg
            is_op.positions = inst.positions
            if sys.version_info < (3, 12):
                jump_op = create_instruction(
                    "POP_JUMP_FORWARD_IF_TRUE"
                    if "FORWARD" in inst.opname
                    else "POP_JUMP_BACKWARD_IF_TRUE",
                    target=inst.target,
                )
            else:
                jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
            jump_op.positions = inst.positions
            # update inst.exn_tab_entry.end if necessary
            if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
                inst.exn_tab_entry.end = jump_op
            # preserve exception table entries
            is_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
            jump_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
            # modify inst in-place to preserve jump target
            inst.opcode = dis.opmap["LOAD_CONST"]
            inst.opname = "LOAD_CONST"
            inst.arg = None
            inst.argval = None
            new_insts.extend([is_op, jump_op])
    instructions[:] = new_insts


def remove_binary_store_slice(instructions: List[Instruction]) -> None:
    """
    Rewrite BINARY_SLICE / STORE_SLICE as BUILD_SLICE 2 followed by
    BINARY_SUBSCR / STORE_SUBSCR, in place.
    """
    new_insts = []
    for inst in instructions:
        new_insts.append(inst)
        if inst.opname in ("BINARY_SLICE", "STORE_SLICE"):
            # new instruction
            subscr_inst = create_instruction(inst.opname.replace("SLICE", "SUBSCR"))
            if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
                inst.exn_tab_entry.end = subscr_inst
            subscr_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
            subscr_inst.positions = inst.positions
            # modify inst in-place to preserve jump target
            inst.opcode = dis.opmap["BUILD_SLICE"]
            inst.opname = "BUILD_SLICE"
            inst.arg = 2
            inst.argval = 2
            new_insts.append(subscr_inst)
    instructions[:] = new_insts


# fused superinstruction -> the two plain instructions it stands for
FUSED_INSTS = {
    "LOAD_FAST_LOAD_FAST": ("LOAD_FAST", "LOAD_FAST"),
    "STORE_FAST_STORE_FAST": ("STORE_FAST", "STORE_FAST"),
    "STORE_FAST_LOAD_FAST": ("STORE_FAST", "LOAD_FAST"),
}


def remove_fused_load_store(instructions: List[Instruction]) -> None:
    new_insts = []
    for inst in instructions:
        new_insts.append(inst)
        if inst.opname in FUSED_INSTS:
            inst0, inst1 = FUSED_INSTS[inst.opname]
            argval0, argval1 = inst.argval

            # modify inst in-place to preserve jump target
            inst.opcode = dis.opmap[inst0]
            inst.opname = inst0
            inst.argval = argval0

            new_inst = create_instruction(inst1, argval=argval1)
            # update inst.exn_tab_entry.end if necessary
            if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
                inst.exn_tab_entry.end = new_inst
            # preserve exception table entries
            new_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)

new_insts.append(new_inst) 1028 instructions[:] = new_insts 1029 1030 1031def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> None: 1032 """convert super() with no args into explicit arg form""" 1033 cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ()) 1034 if not len(code.co_varnames): 1035 # A function with no argument cannot contain a valid "super()" call 1036 return 1037 output = [] 1038 for idx, inst in enumerate(instructions): 1039 output.append(inst) 1040 if inst.opname == "LOAD_GLOBAL" and inst.argval == "super": 1041 nexti = instructions[idx + 1] 1042 if nexti.arg == 0 and ( 1043 (sys.version_info >= (3, 12) and nexti.opname == "CALL") 1044 or ( 1045 sys.version_info >= (3, 11) 1046 and sys.version_info < (3, 12) 1047 and nexti.opname == "PRECALL" 1048 ) 1049 or (sys.version_info < (3, 11) and nexti.opname == "CALL_FUNCTION") 1050 ): 1051 assert "__class__" in cell_and_free 1052 output.append(create_instruction("LOAD_DEREF", argval="__class__")) 1053 first_var = code.co_varnames[0] 1054 if first_var in cell_and_free: 1055 output.append(create_instruction("LOAD_DEREF", argval=first_var)) 1056 else: 1057 output.append(create_instruction("LOAD_FAST", argval=first_var)) 1058 nexti.arg = 2 1059 nexti.argval = 2 1060 if nexti.opname == "PRECALL": 1061 # also update the following CALL instruction 1062 call_inst = instructions[idx + 2] 1063 call_inst.arg = 2 1064 call_inst.argval = 2 1065 1066 instructions[:] = output 1067 1068 1069def fix_extended_args(instructions: List[Instruction]) -> int: 1070 """Fill in correct argvals for EXTENDED_ARG ops""" 1071 output: List[Instruction] = [] 1072 1073 def maybe_pop_n(n): 1074 for _ in range(n): 1075 if output and output[-1].opcode == dis.EXTENDED_ARG: 1076 output.pop() 1077 1078 for inst in instructions: 1079 if inst.opcode == dis.EXTENDED_ARG: 1080 # Leave this instruction alone for now so we never shrink code 1081 inst.arg = 0 1082 elif inst.arg and inst.arg > 0xFFFFFF: 1083 
maybe_pop_n(3) 1084 output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 24)) 1085 output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) 1086 output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) 1087 elif inst.arg and inst.arg > 0xFFFF: 1088 maybe_pop_n(2) 1089 output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 16)) 1090 output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) 1091 elif inst.arg and inst.arg > 0xFF: 1092 maybe_pop_n(1) 1093 output.append(create_instruction("EXTENDED_ARG", arg=inst.arg >> 8)) 1094 output.append(inst) 1095 1096 added = len(output) - len(instructions) 1097 assert added >= 0 1098 instructions[:] = output 1099 return added 1100 1101 1102def instruction_size(inst) -> int: 1103 import torch 1104 1105 if sys.version_info >= (3, 11): 1106 return 2 * (torch._C._dynamo.eval_frame.py_opcode_caches[inst.opcode] + 1) 1107 return 2 1108 1109 1110def check_offsets(instructions) -> None: 1111 offset = 0 1112 for inst in instructions: 1113 assert inst.offset == offset 1114 offset += instruction_size(inst) 1115 1116 1117def update_offsets(instructions) -> None: 1118 offset = 0 1119 for inst in instructions: 1120 inst.offset = offset 1121 offset += instruction_size(inst) 1122 1123 1124def debug_bytes(*args) -> str: 1125 index = range(max(map(len, args))) 1126 result = [] 1127 for arg in ( 1128 [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]] 1129 ): 1130 result.append(" ".join(f"{x:03}" for x in arg)) 1131 1132 return "bytes mismatch\n" + "\n".join(result) 1133 1134 1135def debug_checks(code): 1136 """Make sure our assembler produces same bytes as we start with""" 1137 dode = transform_code_object(code, lambda x, y: None, safe=True) 1138 assert code.co_code == dode.co_code, debug_bytes(code.co_code, dode.co_code) 1139 assert code.co_lnotab == dode.co_lnotab, debug_bytes(code.co_lnotab, dode.co_lnotab) 1140 1141 1142HAS_LOCAL = set(dis.haslocal) 1143HAS_NAME 
= set(dis.hasname) 1144HAS_FREE = set(dis.hasfree) 1145HAS_CONST = set(dis.hasconst) 1146 1147 1148def get_const_index(code_options, val) -> int: 1149 for i, v in enumerate(code_options["co_consts"]): 1150 # NOTE: stronger comparison is required, since we have 1151 # examples where two values compare equal but have 1152 # different semantic meaning in some cases, e.g. 1153 # 0.0 == -0.0 but have different effects in torch.copysign. 1154 if val is v: 1155 return i 1156 code_options["co_consts"] += (val,) 1157 return len(code_options["co_consts"]) - 1 1158 1159 1160def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None): 1161 # compute instruction arg from argval if arg is not provided 1162 names = {name: idx for idx, name in enumerate(code_options["co_names"])} 1163 1164 def get_name_index(name) -> int: 1165 try: 1166 idx = names[name] 1167 except KeyError: 1168 # Add a missing item to co_names 1169 idx = names[name] = len(names) 1170 code_options["co_names"] = (*code_options["co_names"], name) 1171 assert len(code_options["co_names"]) == len(names) 1172 return idx 1173 1174 if sys.version_info < (3, 11): 1175 assert varname_from_oparg is None 1176 varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])} 1177 freenames = { 1178 name: idx 1179 for idx, name in enumerate( 1180 code_options["co_cellvars"] + code_options["co_freevars"] 1181 ) 1182 } 1183 else: 1184 assert callable(varname_from_oparg) 1185 allnames = {} 1186 for idx in itertools.count(): 1187 try: 1188 name = varname_from_oparg(idx) 1189 allnames[name] = idx 1190 except IndexError: 1191 break 1192 varnames = {name: allnames[name] for name in code_options["co_varnames"]} 1193 freenames = { 1194 name: allnames[name] 1195 for name in code_options["co_cellvars"] + code_options["co_freevars"] 1196 } 1197 for i in range(len(instructions)): 1198 1199 def should_compute_arg(): 1200 # argval is prioritized over arg 1201 return instructions[i].argval is not 
_NotProvided 1202 1203 if instructions[i].opname == "LOAD_GLOBAL": 1204 # 3.11 LOAD_GLOBAL requires both arg and argval - see create_instruction 1205 assert instructions[i].argval is not _NotProvided 1206 if sys.version_info >= (3, 11): 1207 assert instructions[i].arg is not None 1208 instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( 1209 cast(int, instructions[i].arg) % 2 1210 ) 1211 else: 1212 instructions[i].arg = get_name_index(instructions[i].argval) 1213 elif instructions[i].opname == "LOAD_ATTR": 1214 # 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL 1215 assert instructions[i].argval is not _NotProvided 1216 if sys.version_info >= (3, 12): 1217 assert instructions[i].arg is not None 1218 instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( 1219 cast(int, instructions[i].arg) % 2 1220 ) 1221 else: 1222 instructions[i].arg = get_name_index(instructions[i].argval) 1223 elif instructions[i].opname == "LOAD_SUPER_ATTR": 1224 assert instructions[i].arg is not None 1225 assert instructions[i].argval is not _NotProvided 1226 # Copy low bit, force second bit on for explicit super (the "+ 2") 1227 instructions[i].arg = ( 1228 (get_name_index(instructions[i].argval) << 2) 1229 + (cast(int, instructions[i].arg) % 2) 1230 + 2 1231 ) 1232 elif instructions[i].opcode in HAS_LOCAL: 1233 if should_compute_arg(): 1234 if ( 1235 sys.version_info >= (3, 13) 1236 and instructions[i].argval not in varnames 1237 ): 1238 # instructions like LOAD_FAST used for both local and free vars 1239 instructions[i].arg = freenames[instructions[i].argval] 1240 else: 1241 instructions[i].arg = varnames[instructions[i].argval] 1242 elif instructions[i].opcode in HAS_NAME: 1243 if should_compute_arg(): 1244 instructions[i].arg = get_name_index(instructions[i].argval) 1245 elif instructions[i].opcode in HAS_FREE: 1246 if should_compute_arg(): 1247 instructions[i].arg = freenames[instructions[i].argval] 1248 elif instructions[i].opcode in 
HAS_CONST: 1249 # NOTE: only update argval if arg is not provided. This assumes 1250 # that any additions to co_consts are appended. 1251 if instructions[i].arg is None: 1252 # cannot use a dictionary since consts may not be hashable 1253 idx = get_const_index(code_options, instructions[i].argval) 1254 assert idx >= 0 1255 instructions[i].arg = idx 1256 1257 1258def clear_instruction_args(instructions): 1259 # Clear the instruction arg for instructions that have argvals. 1260 # Useful for using dis'd bytecode within generated bytecode. 1261 for inst in instructions: 1262 if ( 1263 inst.argval is not _NotProvided 1264 and ( 1265 inst.opcode in HAS_LOCAL 1266 or inst.opcode in HAS_NAME 1267 or inst.opcode in HAS_FREE 1268 or inst.opcode in HAS_CONST 1269 ) 1270 and inst.opname not in ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR") 1271 ): 1272 inst.arg = None 1273 1274 1275def get_code_keys() -> List[str]: 1276 # Python 3.11 changes to code keys are not fully documented. 1277 # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24 1278 # for new format. 
1279 keys = ["co_argcount"] 1280 keys.append("co_posonlyargcount") 1281 keys.extend( 1282 [ 1283 "co_kwonlyargcount", 1284 "co_nlocals", 1285 "co_stacksize", 1286 "co_flags", 1287 "co_code", 1288 "co_consts", 1289 "co_names", 1290 "co_varnames", 1291 "co_filename", 1292 "co_name", 1293 ] 1294 ) 1295 if sys.version_info >= (3, 11): 1296 keys.append("co_qualname") 1297 keys.append("co_firstlineno") 1298 if sys.version_info >= (3, 10): 1299 keys.append("co_linetable") 1300 else: 1301 keys.append("co_lnotab") 1302 if sys.version_info >= (3, 11): 1303 # not documented, but introduced in https://github.com/python/cpython/issues/84403 1304 keys.append("co_exceptiontable") 1305 keys.extend( 1306 [ 1307 "co_freevars", 1308 "co_cellvars", 1309 ] 1310 ) 1311 return keys 1312 1313 1314def transform_code_object(code, transformations, safe=False) -> types.CodeType: 1315 keys = get_code_keys() 1316 code_options = {k: getattr(code, k) for k in keys} 1317 assert len(code_options["co_varnames"]) == code_options["co_nlocals"] 1318 1319 instructions = cleaned_instructions(code, safe) 1320 propagate_line_nums(instructions) 1321 1322 transformations(instructions, code_options) 1323 return clean_and_assemble_instructions(instructions, keys, code_options)[1] 1324 1325 1326def clean_and_assemble_instructions( 1327 instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any] 1328) -> Tuple[List[Instruction], types.CodeType]: 1329 # also implicitly checks for no duplicate instructions 1330 check_inst_exn_tab_entries_valid(instructions) 1331 1332 code_options["co_nlocals"] = len(code_options["co_varnames"]) 1333 varname_from_oparg = None 1334 if sys.version_info >= (3, 11): 1335 # temporary code object with updated names 1336 tmp_code = types.CodeType(*[code_options[k] for k in keys]) 1337 varname_from_oparg = tmp_code._varname_from_oparg # type: ignore[attr-defined] 1338 fix_vars(instructions, code_options, varname_from_oparg=varname_from_oparg) 1339 1340 dirty = True 1341 
while dirty: 1342 update_offsets(instructions) 1343 devirtualize_jumps(instructions) 1344 # this pass might change offsets, if so we need to try again 1345 dirty = bool(fix_extended_args(instructions)) 1346 1347 remove_extra_line_nums(instructions) 1348 bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"]) 1349 if sys.version_info < (3, 10): 1350 code_options["co_lnotab"] = lnotab 1351 else: 1352 code_options["co_linetable"] = lnotab 1353 1354 code_options["co_code"] = bytecode 1355 code_options["co_stacksize"] = stacksize_analysis(instructions) 1356 assert set(keys) - {"co_posonlyargcount"} == set(code_options.keys()) - { 1357 "co_posonlyargcount" 1358 } 1359 if sys.version_info >= (3, 11): 1360 code_options["co_exceptiontable"] = assemble_exception_table( 1361 compute_exception_table(instructions) 1362 ) 1363 1364 return instructions, types.CodeType(*[code_options[k] for k in keys]) 1365 1366 1367def populate_kw_names_argval(instructions, consts): 1368 for inst in instructions: 1369 if inst.opname == "KW_NAMES": 1370 inst.argval = consts[inst.arg] 1371 1372 1373def cleaned_instructions(code, safe=False) -> List[Instruction]: 1374 instructions = list(map(convert_instruction, dis.get_instructions(code))) 1375 check_offsets(instructions) 1376 if sys.version_info >= (3, 11): 1377 populate_kw_names_argval(instructions, code.co_consts) 1378 virtualize_exception_table(code.co_exceptiontable, instructions) 1379 virtualize_jumps(instructions) 1380 strip_extended_args(instructions) 1381 if not safe: 1382 if sys.version_info < (3, 11): 1383 remove_load_call_method(instructions) 1384 if sys.version_info < (3, 12): 1385 explicit_super(code, instructions) 1386 if sys.version_info >= (3, 11): 1387 remove_jump_if_none(instructions) 1388 if sys.version_info >= (3, 12): 1389 remove_binary_store_slice(instructions) 1390 if sys.version_info >= (3, 13): 1391 remove_fused_load_store(instructions) 1392 update_offsets(instructions) 1393 
devirtualize_jumps(instructions) 1394 return instructions 1395 1396 1397_unique_id_counter = itertools.count() 1398 1399 1400def unique_id(name) -> str: 1401 return f"{name}_{next(_unique_id_counter)}" 1402 1403 1404def is_generator(code: types.CodeType) -> bool: 1405 co_generator = 0x20 1406 return (code.co_flags & co_generator) > 0 1407 1408 1409def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True): 1410 """Generates bytecode from a template function `fn` for use in 1411 dynamo bytecode generation. 1412 1413 For example, we can generate Python-version-independent bytecode 1414 for looping through a dictionary and copying the values to a new dictionary. 1415 1416 def template(d1, d2): 1417 for k, v in d1.items(): 1418 d2[k] = v 1419 1420 1421 or a try block: 1422 1423 def template(): 1424 try: 1425 dummy1 1426 except: 1427 dummy2 1428 raise 1429 dummy3 1430 1431 Args: 1432 fn: a function template to generate bytecode from 1433 varname_map: a mapping of `fn`'s varnames to new names. This 1434 map will be applied to the generated bytecode's varnames. 1435 For example, local variables in `fn` can be replaced with 1436 new names that are generated by `OutputGraph.new_var`. 1437 noreturn: remove all RETURN_* bytecodes and replace them with a jump 1438 to the end of the bytecode. 1439 noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive). 1440 """ 1441 insts = cleaned_instructions(fn.__code__) 1442 clear_instruction_args(insts) 1443 1444 if noprefix: 1445 for i, inst in enumerate(insts): 1446 if inst.opname == "RESUME": 1447 insts = insts[i + 1 :] 1448 break 1449 1450 for inst in insts: 1451 # If we don't reset starts_line, then the generated 1452 # bytecode's line number will be based on fn's. 
1453 inst.starts_line = None 1454 if varname_map and inst.argval in varname_map: 1455 inst.argval = varname_map[inst.argval] 1456 1457 if noreturn: 1458 if sys.version_info >= (3, 12): 1459 # replace RETURN_CONST with LOAD_CONST RETURN_VALUE 1460 new_insts = [] 1461 for inst in insts: 1462 if inst.opname == "RETURN_CONST": 1463 inst.opcode = dis.opmap["LOAD_CONST"] 1464 inst.opname = "LOAD_CONST" 1465 new_insts.append(inst) 1466 # no need to propagate target/exn table 1467 new_insts.append(create_instruction("RETURN_VALUE")) 1468 else: 1469 new_insts.append(inst) 1470 insts = new_insts 1471 1472 returns = [] 1473 for inst in insts: 1474 if inst.opname == "RETURN_VALUE": 1475 returns.append(inst) 1476 1477 if len(returns) == 1 and returns[0] is insts[-1]: 1478 # only 1 return at the end - just pop it 1479 insts.pop(-1) 1480 elif len(returns) > 0: 1481 # create jump target - if the last inst is a return, 1482 # we can replace it with a NOP and make that the jump target. 1483 if insts[-1] is returns[-1]: 1484 insts[-1].opname = "NOP" 1485 insts[-1].opcode = dis.opmap["NOP"] 1486 insts[-1].arg = None 1487 insts[-1].argval = _NotProvided 1488 returns.pop(-1) 1489 else: 1490 insts.append(create_instruction("NOP")) 1491 1492 # replace returns with jumps 1493 for inst in returns: 1494 # don't replace inst with new instruction 1495 # due to targetting/exn table/etc. 1496 jump_inst = create_jump_absolute(insts[-1]) 1497 inst.opname = jump_inst.opname 1498 inst.opcode = jump_inst.opcode 1499 inst.arg = jump_inst.arg 1500 inst.argval = jump_inst.argval 1501 inst.target = jump_inst.target 1502 1503 return insts 1504