xref: /aosp_15_r20/external/pytorch/torch/_dynamo/bytecode_transformation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import copy
3import dataclasses
4import dis
5import itertools
6import sys
7import types
8from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Tuple, Union
9
10from .bytecode_analysis import (
11    get_indexof,
12    propagate_line_nums,
13    remove_extra_line_nums,
14    stacksize_analysis,
15)
16
17
@dataclasses.dataclass
class InstructionExnTabEntry:
    """
    An exception table entry whose boundaries are `Instruction` objects
    rather than byte offsets, so that it stays valid while instructions
    are inserted, removed or moved.

    `start`, `end` and `target` reference instructions; `depth` and `lasti`
    are carried over unchanged from the parsed exception table entry (see
    CPython's exception_handling_notes.txt for their meaning).
    """

    start: "Instruction"
    end: "Instruction"
    target: "Instruction"
    depth: int
    lasti: bool

    def __repr__(self) -> str:
        return (
            f"InstructionExnTabEntry(start={self.start.short_inst_repr()}, "
            f"end={self.end.short_inst_repr()}, "
            f"target={self.target.short_inst_repr()}, "
            f"depth={self.depth}, lasti={self.lasti})"
        )

    def __eq__(self, o) -> bool:
        # Comparing against an unrelated object (None, an int, ...) must not
        # blow up with AttributeError; defer to Python's default handling,
        # which makes `==` evaluate to False and `!=` to True.
        if not isinstance(o, InstructionExnTabEntry):
            return NotImplemented
        # The instruction fields are compared by identity on purpose: two
        # entries are equal only if they point at the very same instructions.
        return (
            self.start is o.start
            and self.end is o.end
            and self.target is o.target
            and self.depth == o.depth
            and self.lasti == o.lasti
        )
42
43
@dataclasses.dataclass
class Instruction:
    """A mutable version of dis.Instruction

    Instances hash and compare by identity (see `__hash__`/`__eq__` below),
    so two instructions holding equal field values are still different
    instructions.
    """

    opcode: int
    opname: str
    # raw integer operand; may be left as None and filled in later
    arg: Optional[int]
    argval: Any
    # byte offset of the instruction within the bytecode
    offset: Optional[int] = None
    starts_line: Optional[int] = None
    is_jump_target: bool = False
    positions: Optional["dis.Positions"] = None
    # extra fields to make modification easier:
    # instruction this one jumps to (set by virtualize_jumps)
    target: Optional["Instruction"] = None
    # exception table entry covering this instruction, if any
    # (set by virtualize_exception_table)
    exn_tab_entry: Optional[InstructionExnTabEntry] = None

    def __hash__(self) -> int:
        # identity hash, consistent with the identity-based __eq__
        return id(self)

    def __eq__(self, other) -> bool:
        # identity comparison: field values are deliberately ignored
        return id(self) == id(other)

    def short_inst_repr(self) -> str:
        """Compact representation showing only the opname and offset."""
        return f"Instruction(opname={self.opname}, offset={self.offset})"
68
69
def convert_instruction(i: dis.Instruction) -> Instruction:
    """
    Build a mutable `Instruction` from a `dis.Instruction`.

    The line number is read from `line_number` on Python 3.13+ and from
    `starts_line` on earlier versions. `positions` is copied when the
    `dis.Instruction` has that attribute and is None otherwise.
    """
    line_attr = "line_number" if sys.version_info >= (3, 13) else "starts_line"
    return Instruction(
        opcode=i.opcode,
        opname=i.opname,
        arg=i.arg,
        argval=i.argval,
        offset=i.offset,
        starts_line=getattr(i, line_attr),
        is_jump_target=i.is_jump_target,
        positions=getattr(i, "positions", None),
    )
85
86
class _NotProvided:
    """
    Marker meaning "no value was supplied", distinct from None.

    NOTE: create_instruction uses the class object itself (not an instance)
    as the default for `argval` and tests for it with `is`.
    """

    def __repr__(self) -> str:
        # applies to instances only; the class object keeps the default repr
        return "_NotProvided"
90
91
def inst_has_op_bits(name):
    """
    Returns True if, on the running Python version, the instruction called
    `name` carries extra op bits in its arg:
    LOAD_GLOBAL on 3.11+, LOAD_ATTR and LOAD_SUPER_ATTR on 3.12+.
    """
    if name == "LOAD_GLOBAL":
        return sys.version_info >= (3, 11)
    if name in ("LOAD_ATTR", "LOAD_SUPER_ATTR"):
        return sys.version_info >= (3, 12)
    return False
96
97
def create_instruction(
    name, *, arg=None, argval=_NotProvided, target=None
) -> Instruction:
    """
    Creates a new Instruction with opname `name`.

    For ordinary instructions, at most one of `arg`, `argval` and `target`
    may be given (i.e. be not None/_NotProvided). This avoids ambiguity:
    without the restriction it would be unclear whether an operand of 5 on a
    LOAD_CONST means "the constant at co_consts[5]" (`arg=5`) or
    "the constant 5" (`argval=5`).

    When `arg` is not given, it is computed during assembly from `argval`
    or `target`.

    Instructions with op bits (LOAD_GLOBAL, LOAD_ATTR on 3.12+ and
    LOAD_SUPER_ATTR) are the exception: their arg contains bits that modify
    the instruction's behavior, so `arg` and `argval` may both be given.
    There, `arg` is expected to hold only the op bits (defaulting to 0) and
    the real arg is computed during assembly. `target` is not allowed.

    Raises RuntimeError if the combination of arguments is invalid or if
    `arg` is neither an int nor None.
    """
    if inst_has_op_bits(name):
        # arg + argval is fine here; a jump target makes no sense
        if target is not None:
            raise RuntimeError("target cannot be specified for instruction")
        if arg is None:
            arg = 0
    else:
        provided = (
            arg is not None,
            argval is not _NotProvided,
            target is not None,
        )
        if sum(provided) > 1:
            raise RuntimeError(
                "only one of arg, argval, and target can be not None/_NotProvided"
            )
    if not (arg is None or isinstance(arg, int)):
        raise RuntimeError("instruction arg must be int or None")
    return Instruction(
        opcode=dis.opmap[name], opname=name, arg=arg, argval=argval, target=target
    )
134
135
136# Python 3.11 remaps
def create_jump_absolute(target) -> Instruction:
    """
    Creates an unconditional jump to `target`: JUMP_FORWARD on Python 3.11+,
    JUMP_ABSOLUTE on earlier versions.
    """
    if sys.version_info >= (3, 11):
        return create_instruction("JUMP_FORWARD", target=target)
    return create_instruction("JUMP_ABSOLUTE", target=target)
140
141
def create_dup_top() -> Instruction:
    """
    Creates the instruction that duplicates TOS: DUP_TOP before Python 3.11,
    COPY 1 from 3.11 on.
    """
    if sys.version_info < (3, 11):
        return create_instruction("DUP_TOP")
    return create_instruction("COPY", arg=1)
146
147
def create_rot_n(n) -> List[Instruction]:
    """
    Returns a "simple" instruction sequence that rotates TOS to the n-th
    position in the stack.

    n <= 1 needs no rotation and yields an empty list.
    On Python >= 3.11 every rotation is expressed as a chain of SWAPs.
    On Python < 3.11 a single ROT_* instruction is returned; if the running
    version has no such instruction, AttributeError is raised and the caller
    is expected to generate an equivalent sequence itself.
    """
    if n <= 1:
        return []

    if sys.version_info >= (3, 11):
        # e.g. rotate 3 == SWAP 3, SWAP 2
        swaps = []
        for depth in range(n, 1, -1):
            swaps.append(create_instruction("SWAP", arg=depth))
        return swaps

    # make sure the rotate instruction we want exists on this version
    if sys.version_info < (3, 8) and n >= 4:
        raise AttributeError(f"rotate {n} not supported for Python < 3.8")
    if sys.version_info < (3, 10) and n >= 5:
        raise AttributeError(f"rotate {n} not supported for Python < 3.10")

    if n > 4:
        return [create_instruction("ROT_N", arg=n)]
    suffixes = ("TWO", "THREE", "FOUR")
    return [create_instruction("ROT_" + suffixes[n - 2])]
175
176
def add_push_null(
    inst_or_insts: Union[Instruction, List[Instruction]],
) -> List[Instruction]:
    """
    Adds the NULL that accompanies a callable about to be called, either by
    setting the NULL-push bit of an existing instruction or by adding a
    PUSH_NULL before or after `inst_or_insts`, depending on Python version.
    Below Python 3.11 nothing is added.

    NOTE: `inst_or_insts` must be a single instruction or a sequence of
    instructions that pushes exactly 1 object - the callable - onto the
    stack. Pass ALL instructions that build the callable, not just the
    first one or a prefix.

    Instructions with op bits (LOAD_GLOBAL 3.11+, LOAD_ATTR 3.12+,
    LOAD_SUPER_ATTR) have their low bit set when possible; in that case the
    given instructions ARE modified in place. The input list itself is
    never mutated - a new list is returned when PUSH_NULL is added.
    """
    insts = (
        [inst_or_insts] if isinstance(inst_or_insts, Instruction) else inst_or_insts
    )

    def try_set_null_bit(idx):
        # Sets the low (push NULL) bit of insts[idx] unless it is already
        # set; returns whether the bit was set by this call.
        inst = insts[idx]
        assert inst.arg is not None
        if inst.arg & 1 == 1:
            return False
        inst.arg |= 1
        return True

    if sys.version_info >= (3, 13):
        # In 3.13, NULL follows the callable. All insts with op bits keep
        # the push_null bit as the last one; if it is already taken we need
        # a separate PUSH_NULL.
        if not (inst_has_op_bits(insts[-1].opname) and try_set_null_bit(-1)):
            insts = insts + [create_instruction("PUSH_NULL")]
    elif sys.version_info >= (3, 12):
        # Prefer a trailing LOAD_ATTR/LOAD_SUPER_ATTR, then a leading
        # LOAD_GLOBAL. `insts` is assumed to load a single object, so a
        # trailing LOAD_GLOBAL needs no dedicated check.
        used_bit = (
            inst_has_op_bits(insts[-1].opname) and try_set_null_bit(-1)
        ) or (insts[0].opname == "LOAD_GLOBAL" and try_set_null_bit(0))
        if not used_bit:
            insts = [create_instruction("PUSH_NULL")] + insts
    elif sys.version_info >= (3, 11):
        # 3.11 introduced NULL preceding callable
        if not (inst_has_op_bits(insts[0].opname) and try_set_null_bit(0)):
            insts = [create_instruction("PUSH_NULL")] + insts
    return insts
233
234
def add_push_null_call_function_ex(
    inst_or_insts: Union[Instruction, List[Instruction]],
) -> List[Instruction]:
    """
    Variant of add_push_null for a callable that will be invoked through
    CALL_FUNCTION_EX: the low bit of LOAD_ATTR/LOAD_SUPER_ATTR is never
    touched. Only a LOAD_GLOBAL (last instruction on 3.13+, first one
    otherwise) may have its low bit set in place; in every other case a
    PUSH_NULL is appended (3.13+) or prepended. Below 3.11 the instructions
    are returned unchanged.
    """
    insts = (
        [inst_or_insts] if isinstance(inst_or_insts, Instruction) else inst_or_insts
    )

    if sys.version_info < (3, 11):
        return insts

    null_after_callable = sys.version_info >= (3, 13)
    candidate = insts[-1] if null_after_callable else insts[0]
    if candidate.opname == "LOAD_GLOBAL":
        assert candidate.arg is not None
        if candidate.arg & 1 == 0:  # type: ignore[operator]
            candidate.arg |= 1  # type: ignore[operator]
            return insts

    push_null = [create_instruction("PUSH_NULL")]
    if null_after_callable:
        return insts + push_null
    return push_null + insts
262
263
def create_call_function(nargs, push_null) -> List[Instruction]:
    """
    Creates the instruction sequence for calling a function with `nargs`
    positional arguments.

    `push_null` only matters on Python 3.11+, where calls use the
    NULL + callable convention:

    - `push_null=True`: no NULL has been pushed for the callable yet. A
      PUSH_NULL is emitted and rotated into place right before the call.
    - `push_null=False`: a NULL was already pushed for the callable.

    Prefer `push_null=False` whenever possible - rotating the NULL into
    position costs extra instructions. The usual recipe is `add_push_null`
    on the instructions that build the callable, followed by
    `create_call_function(nargs, False)`.

    Example with push_null=False:

    insts = [
        create_instruction("LOAD_GLOBAL", argval="torch"),
        create_instruction("LOAD_ATTR", argval="nn"),
        create_instruction("LOAD_ATTR", argval="functional"),
        create_instruction("LOAD_ATTR", argval="relu"),
    ]
    insts = add_push_null(insts)
    insts.append(create_instruction("LOAD_FAST", argval="x"))
    insts.extend(create_call_function(1, False))

    Example with push_null=True (the argument is already on the stack, so
    the NULL can only be added after the fact):

    insts = [create_instruction("LOAD_FAST", argval="x")]
    for should_wrap, wrapper_name in wrappers:
        if should_wrap:
            insts.extend([
                create_instruction("LOAD_GLOBAL", argval=wrapper_name),
                create_instruction("SWAP", arg=2),
                *create_call_function(1, True),
            ])
    """
    if sys.version_info < (3, 11):
        return [create_instruction("CALL_FUNCTION", arg=nargs)]

    output = []
    if push_null:
        output.append(create_instruction("PUSH_NULL"))
        # 3.13 swapped NULL and callable, so the NULL travels one slot less
        if sys.version_info >= (3, 13):
            rots = nargs + 1
        else:
            rots = nargs + 2
        output.extend(create_rot_n(rots))
    if sys.version_info < (3, 12):
        output.append(create_instruction("PRECALL", arg=nargs))
    output.append(create_instruction("CALL", arg=nargs))
    return output
317
318
def create_call_method(nargs) -> List[Instruction]:
    """
    Creates the instructions for a method call with `nargs` arguments:
    CALL_METHOD before Python 3.11, PRECALL + CALL on 3.11, CALL on 3.12+.
    """
    if sys.version_info < (3, 11):
        return [create_instruction("CALL_METHOD", arg=nargs)]
    insts = []
    if sys.version_info < (3, 12):
        insts.append(create_instruction("PRECALL", arg=nargs))
    insts.append(create_instruction("CALL", arg=nargs))
    return insts
328
329
def create_load_method(name) -> Instruction:
    """
    Creates the instruction that loads method `name`: LOAD_METHOD before
    Python 3.12, and a LOAD_ATTR with the low op bit set from 3.12 on.
    """
    if sys.version_info < (3, 12):
        return create_instruction("LOAD_METHOD", argval=name)
    # arg=1 carries the op bits only, see create_instruction
    return create_instruction("LOAD_ATTR", arg=1, argval=name)
335
336
def create_setup_with(target) -> Instruction:
    """
    Creates the instruction that enters a `with` block: BEFORE_WITH on
    Python 3.11+, SETUP_WITH on earlier versions. `target` is passed through
    as the instruction's target on both.
    """
    if sys.version_info >= (3, 11):
        return create_instruction("BEFORE_WITH", target=target)
    return create_instruction("SETUP_WITH", target=target)
340
341
def create_swap(n) -> List[Instruction]:
    """
    Creates instructions that swap TOS with the n-th item from the top of
    the stack. On Python 3.11+ this is a single SWAP; on earlier versions
    it is emulated by packing the top n - 1 items into a temporary list.
    """
    if sys.version_info >= (3, 11):
        return [create_instruction("SWAP", arg=n)]
    # in Python < 3.11, SWAP is a macro that expands to multiple instructions
    if n == 1:
        return []

    # Worked example, swapping "a" and "b" in the stack `0 a 1 2 3 b`:
    #
    #   BUILD_LIST       0 a [1 2 3 b]
    #   DUP_TOP          0 a [1 2 3 b] [1 2 3 b]
    #   LOAD_CONST -1    0 a [1 2 3 b] [1 2 3 b] -1
    #   BINARY_SUBSCR    0 a [1 2 3 b] b
    #   ROT_THREE        0 b a [1 2 3 b]
    #   DUP_TOP          0 b a [1 2 3 b] [1 2 3 b]
    #   ROT_THREE        0 b [1 2 3 b] a [1 2 3 b]
    #   LOAD_CONST -1    0 b [1 2 3 b] a [1 2 3 b] -1
    #   STORE_SUBSCR     0 b [1 2 3 a]
    #   DUP_TOP          0 b [1 2 3 a] [1 2 3 a]
    #   load "reverse"   0 b [1 2 3 a] [1 2 3 a] reverse
    #   call             0 b [a 3 2 1] None
    #   POP_TOP          0 b [a 3 2 1]
    #   UNPACK_SEQUENCE  0 b 1 2 3 a

    # pack the top n - 1 items and fetch the old TOS out of the list
    fetch_tos = [
        create_instruction("BUILD_LIST", arg=n - 1),
        create_instruction("DUP_TOP"),
        create_instruction("LOAD_CONST", argval=-1),
        create_instruction("BINARY_SUBSCR"),
    ]
    # put the old TOS below the list and store the n-th item in its place
    exchange = [
        create_instruction("ROT_THREE"),
        create_instruction("DUP_TOP"),
        create_instruction("ROT_THREE"),
        create_instruction("LOAD_CONST", argval=-1),
        create_instruction("STORE_SUBSCR"),
    ]
    # reverse the list so that unpacking restores the original order
    unpack = [
        create_instruction("DUP_TOP"),
        create_load_method("reverse"),
        *create_call_method(0),
        create_instruction("POP_TOP"),
        create_instruction("UNPACK_SEQUENCE", arg=n - 1),
    ]
    return fetch_tos + exchange + unpack
382
383
def lnotab_writer(
    lineno: int, byteno: int = 0
) -> Tuple[List[int], Callable[[int, int], None]]:
    """
    Used to create types.CodeType.co_lnotab
    See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
    This is the internal format of the line number table if Python < 3.10

    Returns `(lnotab, update)`. Calling `update(lineno_new, byteno_new)`
    appends (byte increment, line increment) pairs to `lnotab` until the
    writer's position has advanced from its current (lineno, byteno) to the
    new one. Byte increments are limited to 0..255 and line increments to
    -128..127 per pair, so a large jump is spread over several pairs.
    """
    assert sys.version_info < (3, 10)
    lnotab: List[int] = []

    def update(lineno_new, byteno_new):
        nonlocal byteno, lineno
        while (byteno, lineno) != (byteno_new, lineno_new):
            byte_step = min(max(byteno_new - byteno, 0), 255)
            line_step = min(max(lineno_new - lineno, -128), 127)
            # no progress possible, e.g. when asked to move backwards in bytes
            assert byte_step != 0 or line_step != 0
            byteno += byte_step
            lineno += line_step
            # the line increment is stored as a signed byte
            lnotab.extend((byte_step, line_step & 0xFF))

    return lnotab, update
406
407
def linetable_310_writer(first_lineno):
    """
    Used to create types.CodeType.co_linetable
    See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
    This is the internal format of the line number table for Python 3.10

    Returns `(linetable, update, end)`:
    - `update(lineno_new, byteno_new)` writes out the byte range that ended
      at `byteno_new` together with the line delta recorded by the previous
      `update`, then remembers the new line delta.
    - `end(total_bytes)` writes out the final, still pending range.
    """
    assert sys.version_info >= (3, 10) and sys.version_info < (3, 11)
    linetable: List[int] = []
    lineno = first_lineno
    lineno_delta = 0
    byteno = 0

    def _update(bytes_left, lines_left):
        # emit (byte delta, line delta) pairs until both deltas are used up;
        # a pair holds at most 254 bytes and a line delta within -127..127
        while bytes_left != 0 or lines_left != 0:
            byte_step = min(max(bytes_left, 0), 254)
            line_step = min(max(lines_left, -127), 127)
            assert byte_step != 0 or line_step != 0
            bytes_left -= byte_step
            lines_left -= line_step
            linetable.extend((byte_step, line_step & 0xFF))

    def update(lineno_new, byteno_new):
        nonlocal lineno, lineno_delta, byteno
        span = byteno_new - byteno
        byteno = byteno_new
        # the pending line delta belongs to the range that just ended
        _update(span, lineno_delta)
        lineno_delta = lineno_new - lineno
        lineno = lineno_new

    def end(total_bytes):
        _update(total_bytes - byteno, lineno_delta)

    return linetable, update, end
441
442
def encode_varint(n: int) -> List[int]:
    """
    6-bit chunk encoding of an unsigned integer, least significant chunk
    first. Every chunk except the last has bit 6 (value 64) set to signal
    that more chunks follow.
    See https://github.com/python/cpython/blob/3.11/Objects/locations.md
    """
    assert n >= 0
    chunks: List[int] = []
    while True:
        chunk = n & 63
        n >>= 6
        if n > 0:
            # more chunks follow: set the continuation bit
            chunks.append(chunk | 64)
        else:
            chunks.append(chunk)
            return chunks
456
457
def linetable_311_writer(first_lineno: int):
    """
    Used to create types.CodeType.co_linetable
    See https://github.com/python/cpython/blob/3.11/Objects/locations.md
    This is the internal format of the line number table for Python 3.11

    Returns `(linetable, update)`. `update(positions, inst_size)` appends
    the location entries for one instruction, where `inst_size` is the
    instruction's size in code units. A single entry covers at most 8 code
    units, so a longer instruction is written as several entries.
    """
    assert sys.version_info >= (3, 11)
    linetable = []
    lineno = first_lineno

    def update(positions: "dis.Positions", inst_size):
        nonlocal lineno
        lineno_new = positions.lineno if positions else None

        def _update(delta, size):
            assert 0 < size <= 8
            # first byte - use 13 (no column info) if positions is
            # malformed, otherwise use 14 (long form)
            other_varints: Tuple[int, ...] = ()
            if (
                positions
                and positions.lineno is not None
                and positions.end_lineno is not None
                and positions.col_offset is not None
                and positions.end_col_offset is not None
            ):
                linetable.append(0b1_1110_000 + size - 1)
                # for whatever reason, column offset needs `+ 1`
                # https://github.com/python/cpython/blob/1931c2a438c50e6250725c84dff94fc760b9b951/Python/compile.c#L7603
                other_varints = (
                    positions.end_lineno - positions.lineno,
                    positions.col_offset + 1,
                    positions.end_col_offset + 1,
                )
            else:
                linetable.append(0b1_1101_000 + size - 1)
            # encode signed int: magnitude shifted left, sign in the low bit
            if delta < 0:
                delta = ((-delta) << 1) | 1
            else:
                delta <<= 1
            # encode unsigned int
            linetable.extend(encode_varint(delta))
            for n in other_varints:
                linetable.extend(encode_varint(n))

        if lineno_new is None:
            lineno_delta = 0
        else:
            lineno_delta = lineno_new - lineno
            lineno = lineno_new
        while inst_size > 8:
            _update(lineno_delta, 8)
            # Line deltas are relative to the previous entry and accumulate
            # when the table is read back. The entry just written already
            # moved the line, so the remaining entries of this instruction
            # must not apply the delta a second time.
            lineno_delta = 0
            inst_size -= 8
        _update(lineno_delta, inst_size)

    return linetable, update
515
516
@dataclasses.dataclass
class ExceptionTableEntry:
    """
    One exception table entry expressed in byte offsets, as produced by
    parse_exception_table and consumed by assemble_exception_table.
    """

    # byte offset where the covered range starts
    start: int
    # byte offset of the last 2-byte unit in the range (inclusive):
    # parse_exception_table computes it as start + length - 2
    end: int
    # byte offset of the handler
    target: int
    # upper bits of the combined depth/lasti varint
    depth: int
    # low bit of the combined depth/lasti varint
    lasti: bool
524
525
def encode_exception_table_varint(n: int) -> List[int]:
    """
    Similar to `encode_varint`, but the 6-bit chunks are emitted most
    significant first. Every chunk except the last has bit 6 (value 64)
    set to signal that more chunks follow.
    """
    assert n >= 0
    # the least significant chunk goes last and has no continuation bit
    chunks = [n & 63]
    n >>= 6
    while n > 0:
        chunks.insert(0, (n & 63) | 64)
        n >>= 6
    return chunks
540
541
def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int:
    """
    Inverse of `encode_exception_table_varint`.

    Consumes exactly one varint from `bytes_iter`: chunks are read while
    bit 6 (value 64) of the current byte is set, each contributing its low
    6 bits. Bit 7 is ignored. StopIteration from the iterator propagates to
    the caller.
    """
    byte = next(bytes_iter)
    value = byte & 63
    while byte & 64:
        byte = next(bytes_iter)
        value = (value << 6) | (byte & 63)
    return value
553
554
def check_exception_table(tab: List["ExceptionTableEntry"]) -> None:
    """
    Verifies that a list of ExceptionTableEntries will make a well-formed
    jump table: entries are non-empty, sorted, and do not overlap.

    Every entry is checked, including the only entry of a single-entry
    table. Raises AssertionError on the first violation; an empty table is
    trivially valid.
    """
    prev_end = None
    for entry in tab:
        # `end` is inclusive, so start == end is the smallest valid range
        assert entry.start <= entry.end
        # each range must begin strictly after the previous one ended
        assert prev_end is None or prev_end < entry.start
        prev_end = entry.end
566
567
def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]:
    """
    Parse the exception table according to
    https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt

    Each entry consists of four varints: start, length, target and a
    combined depth/lasti value. Start, length and target are stored in code
    units and converted to byte offsets here. Parsing stops when the bytes
    run out; the result is validated with check_exception_table.
    """
    stream = iter(exntab)

    def read_varint():
        return decode_exception_table_varint(stream)

    entries = []
    while True:
        try:
            start = read_varint() * 2
            length = read_varint() * 2
            target = read_varint() * 2
            depth_lasti = read_varint()
        except StopIteration:
            # input exhausted: no further entries
            break
        entries.append(
            ExceptionTableEntry(
                start,
                # `end` is inclusive: offset of the last 2-byte unit
                start + length - 2,
                target,
                depth_lasti >> 1,
                bool(depth_lasti & 1),
            )
        )
    check_exception_table(entries)
    return entries
588
589
def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
    """
    Inverse of parse_exception_table - encodes list of exception
    table entries into bytes.

    Byte offsets are converted back to code units, and the first byte of
    every entry gets bit 7 set.
    """
    encoded: List[int] = []
    for entry in tab:
        start_bytes = encode_exception_table_varint(entry.start // 2)
        # bit 7 is set on the first byte of each entry
        start_bytes[0] |= 0x80
        # `end` is inclusive, hence the + 2
        length_units = (entry.end - entry.start + 2) // 2
        depth_lasti = (entry.depth << 1) + entry.lasti
        encoded += start_bytes
        encoded += encode_exception_table_varint(length_units)
        encoded += encode_exception_table_varint(entry.target // 2)
        encoded += encode_exception_table_varint(depth_lasti)
    return bytes(encoded)
606
607
def assemble(instructions: List[Instruction], firstlineno: int) -> Tuple[bytes, bytes]:
    """Do the opposite of dis.get_instructions()

    Encodes `instructions` into raw bytecode and builds the matching line
    number table in the format of the running Python version.

    Returns `(code, linetable)`, both as bytes. Only the low byte of each
    instruction's arg is written; larger args are expected to be spelled
    out with explicit EXTENDED_ARG instructions in `instructions`.
    """
    code: List[int] = []
    if sys.version_info >= (3, 11):
        lnotab, update_lineno = linetable_311_writer(firstlineno)
        for i, inst in enumerate(instructions):
            if inst.opname == "EXTENDED_ARG":
                # each EXTENDED_ARG gets its own location entry of one
                # code unit
                inst_size = 1
                # copy positions from the actual instruction
                for j in (1, 2, 3):
                    if instructions[i + j].opname != "EXTENDED_ARG":
                        inst.positions = instructions[i + j].positions
                        break
            else:
                inst_size = instruction_size(inst) // 2
            update_lineno(inst.positions, inst_size)
            arg = inst.arg or 0
            code.extend((inst.opcode, arg & 0xFF))
            # fill the rest of the instruction with zeroed code units
            # (presumably its inline cache entries - confirm against
            # instruction_size)
            for _ in range(instruction_size(inst) // 2 - 1):
                code.extend((0, 0))
    else:
        if sys.version_info < (3, 10):
            lnotab, update_lineno = lnotab_writer(firstlineno)
        else:
            lnotab, update_lineno, end = linetable_310_writer(firstlineno)

        for inst in instructions:
            if inst.starts_line is not None:
                update_lineno(inst.starts_line, len(code))
            arg = inst.arg or 0
            code.extend((inst.opcode, arg & 0xFF))

        if sys.version_info >= (3, 10):
            # the 3.10 format needs the final range flushed explicitly
            end(len(code))

    return bytes(code), bytes(lnotab)
648
649
def _get_instruction_by_offset(offset_to_inst: Dict[int, Instruction], offset: int):
    """
    Get the instruction located at a given offset, accounting for EXTENDED_ARGs

    Starting at `offset`, up to three leading EXTENDED_ARG instructions are
    skipped in 2-byte steps. Returns None if the four instructions examined
    are all EXTENDED_ARG. A missing offset raises KeyError.
    """
    for skip in range(0, 8, 2):
        candidate = offset_to_inst[offset + skip]
        if candidate.opcode != dis.EXTENDED_ARG:
            return candidate
    return None
658
659
def virtualize_jumps(instructions) -> None:
    """
    Replace jump targets with pointers to make editing easier

    For every jump instruction, `target` is set to the instruction found at
    the offset given by the jump's `argval`.
    """
    by_offset = {}
    for inst in instructions:
        by_offset[inst.offset] = inst

    for inst in instructions:
        is_jump = inst.opcode in dis.hasjabs or inst.opcode in dis.hasjrel
        if not is_jump:
            continue
        inst.target = _get_instruction_by_offset(by_offset, inst.argval)
667
668
# Opcodes listed in dis.hasjrel (relative jumps); flip_jump_direction
# checks the opcode it produces against this set.
_REL_JUMPS = set(dis.hasjrel)
670
671
def flip_jump_direction(instruction: Instruction) -> None:
    """
    Turns a *FORWARD* jump into the matching *BACKWARD* jump or vice versa,
    updating both `opname` and `opcode` in place.

    Raises RuntimeError on Python < 3.11 and AttributeError if the opname
    contains neither "FORWARD" nor "BACKWARD".
    """
    if sys.version_info < (3, 11):
        raise RuntimeError("Cannot flip jump direction in Python < 3.11")
    opname = instruction.opname
    if "FORWARD" in opname:
        flipped = opname.replace("FORWARD", "BACKWARD")
    elif "BACKWARD" in opname:
        flipped = opname.replace("BACKWARD", "FORWARD")
    else:
        raise AttributeError("Instruction is not a forward or backward jump")
    instruction.opname = flipped
    instruction.opcode = dis.opmap[flipped]
    assert instruction.opcode in _REL_JUMPS
683
684
def _get_instruction_front(instructions: List[Instruction], idx: int):
    """
    Returns the instruction a jump to instructions[idx] has to land on:
    the first of up to three EXTENDED_ARG instructions directly preceding
    it, or instructions[idx] itself if there are none.
    """
    target = instructions[idx]
    pos = idx
    # walk backwards over at most three consecutive EXTENDED_ARGs
    while (
        idx - pos < 3
        and pos > 0
        and instructions[pos - 1].opcode == dis.EXTENDED_ARG
    ):
        pos -= 1
        target = instructions[pos]
    return target
697
698
def devirtualize_jumps(instructions):
    """
    Fill in args for virtualized jump target after instructions may have moved

    Runs in two passes: first, relative jumps whose direction no longer
    matches the position of their target are flipped (3.11+ only); then,
    with offsets refreshed, `arg`, `argval` and `argrepr` of every jump are
    recomputed from its `target`.

    Raises RuntimeError for a backward relative jump on Python < 3.11 and
    for an absolute jump on Python 3.11+.
    """
    jumps = set(dis.hasjabs) | set(dis.hasjrel)

    # pass 1: check for negative jump args and fix them
    for inst in instructions:
        if inst.opcode not in jumps or inst.opcode in dis.hasjabs:
            continue
        if inst.target.offset < inst.offset:
            if sys.version_info < (3, 11):
                raise RuntimeError("Got negative jump offset for Python < 3.11")
            # forward jumps become backward
            if "FORWARD" in inst.opname:
                flip_jump_direction(inst)
        elif sys.version_info >= (3, 11) and "BACKWARD" in inst.opname:
            # backward jumps become forward
            flip_jump_direction(inst)

    # jump instruction size may have changed due to flips
    update_offsets(instructions)
    indexof = get_indexof(instructions)

    # pass 2: compute jump instruction arg
    for inst in instructions:
        if inst.opcode not in jumps:
            continue
        # jump to the first EXTENDED_ARG of the target, if it has any
        target = _get_instruction_front(instructions, indexof[inst.target])
        if inst.opcode in dis.hasjabs:
            if sys.version_info < (3, 10):
                inst.arg = target.offset
            elif sys.version_info < (3, 11):
                # `arg` is expected to be bytecode offset, whereas `offset`
                # is byte offset. Divide since bytecode is 2 bytes large.
                inst.arg = int(target.offset / 2)
            else:
                raise RuntimeError("Python 3.11+ should not have absolute jumps")
        else:
            # relative jump: byte distance between the target and the
            # instruction following the jump
            distance = int(target.offset - inst.offset - instruction_size(inst))
            inst.arg = abs(distance)
            if sys.version_info >= (3, 10):
                # same byte offset -> bytecode offset conversion as above
                inst.arg //= 2
        inst.argval = target.offset
        inst.argrepr = f"to {target.offset}"
745
746
def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruction]):
    """Replace exception table entries with pointers to make editing easier

    Parses `exn_tab_bytes` and, for every instruction whose offset lies in
    the range of an entry, sets `inst.exn_tab_entry` to an
    InstructionExnTabEntry that references instructions instead of byte
    offsets. Instructions outside every range are left untouched.
    """
    exn_tab = parse_exception_table(exn_tab_bytes)
    offset_to_inst = {cast(int, inst.offset): inst for inst in instructions}
    offsets = sorted(offset_to_inst.keys())
    # index into `offsets`; only ever moves forward across step() calls
    end_offset_idx = 0
    exn_tab_iter = iter(exn_tab)
    # StopIteration raised by next() inside step() ends the whole walk once
    # the exception table is exhausted (this includes an empty table)
    try:

        def step():
            # advance to the next table entry and build its instruction-based
            # counterpart
            nonlocal end_offset_idx
            entry = next(exn_tab_iter)
            # find rightmost offset <= entry.end, since entry.end may not be
            # an actual instruction, e.g. if the end instruction is LOAD_GLOBAL,
            # which takes more than 2 bytes, then entry.end points to the end
            # of the LOAD_GLOBAL instruction, not the beginning.
            while (
                end_offset_idx < len(offsets) and offsets[end_offset_idx] <= entry.end
            ):
                end_offset_idx += 1
            assert end_offset_idx > 0
            end_offset = offsets[end_offset_idx - 1]
            inst_entry = InstructionExnTabEntry(
                _get_instruction_by_offset(offset_to_inst, entry.start),
                _get_instruction_by_offset(offset_to_inst, end_offset),
                _get_instruction_by_offset(offset_to_inst, entry.target),
                entry.depth,
                entry.lasti,
            )
            return entry, inst_entry

        entry, inst_entry = step()
        for inst in instructions:
            # skip entries that end before this instruction
            while inst.offset > entry.end:
                entry, inst_entry = step()
            if inst.offset >= entry.start:
                # each instruction gets its own shallow copy of the entry
                inst.exn_tab_entry = copy.copy(inst_entry)
    except StopIteration:
        pass
786
787
def compute_exception_table(
    instructions: List[Instruction],
) -> List[ExceptionTableEntry]:
    """Compute exception table in list format from instructions with exn_tab_entries

    Collects the distinct (start, end) byte ranges referenced by the
    instructions' `exn_tab_entry` fields, then splits nested ranges into a
    flat list of non-overlapping ExceptionTableEntry objects, which is
    validated with check_exception_table before being returned.
    """
    # (start, end) byte range -> (target, depth, lasti)
    exn_dict: Dict[Tuple[int, int], Tuple[int, int, bool]] = {}
    indexof = get_indexof(instructions)

    for inst in instructions:
        if inst.exn_tab_entry:
            # account for prefixed EXTENDED_ARGS
            start = _get_instruction_front(
                instructions, indexof[inst.exn_tab_entry.start]
            ).offset
            # point to the last 2 bytes of the end instruction
            end = (
                cast(int, inst.exn_tab_entry.end.offset)
                + instruction_size(inst.exn_tab_entry.end)
                - 2
            )
            target = _get_instruction_front(
                instructions, indexof[inst.exn_tab_entry.target]
            ).offset
            key = (start, end)
            val = (target, inst.exn_tab_entry.depth, inst.exn_tab_entry.lasti)
            # the same range must always map to the same handler data
            if key in exn_dict:
                assert exn_dict[key] == val
            exn_dict[key] = val

    # Dynamo may construct nested exception table entries for convenience,
    # but Python expects exception table entries to not overlap.
    # NOTE: below, "keys" refer to old instruction entries' starts and ends,
    # and "entries" refer to the generated exception table entries.

    # Sort keys by increasing start, then decreasing end
    keys_sorted = sorted(exn_dict.keys(), key=lambda t: (t[0], -t[1]))
    # smallest byte that the next exception table entry can start at
    nexti = 0
    # stack of current nested keys
    key_stack: List[Tuple[int, int]] = []
    exn_tab: List[ExceptionTableEntry] = []

    def pop():
        """
        Pop the key_stack and append an exception table entry if possible.
        """
        nonlocal nexti
        if key_stack:
            key = key_stack.pop()
            # emit only the part of the key not yet covered by earlier entries
            if nexti <= key[1]:
                exn_tab.append(
                    ExceptionTableEntry(max(key[0], nexti), key[1], *exn_dict[key])
                )
                nexti = key[1] + 2

    for key in keys_sorted:
        # pop keys that are no longer nested over the current key
        while key_stack and key_stack[-1][1] < key[0]:
            pop()
        if key_stack:
            # create an entry covering to the current key, if possible
            assert key_stack[-1][0] <= key[0] <= key[1] <= key_stack[-1][1]
            left = max(nexti, key_stack[-1][0])
            if left < key[0]:
                # the enclosing key handles the bytes before the nested key
                exn_tab.append(
                    ExceptionTableEntry(left, key[0] - 2, *exn_dict[key_stack[-1]])
                )
            nexti = key[0]
        key_stack.append(key)
    # flush whatever is still open
    while key_stack:
        pop()
    check_exception_table(exn_tab)
    return exn_tab
860
861
def check_inst_exn_tab_entries_nested(
    tab: List[InstructionExnTabEntry], indexof
) -> None:
    """
    Assert that the entries in `tab` form a properly nested sequence.

    `tab` is expected to be ordered by increasing start index and, for equal
    starts, by decreasing end index. Every entry must either lie completely
    inside the innermost still-open entry or begin after that entry has ended;
    a partial overlap trips an assertion.

    Args:
        tab: exception table entries to validate, in sorted order.
        indexof: mapping from an instruction to its index in the
            instruction list.
    """
    open_ranges: List[Tuple[int, int]] = []
    for entry in tab:
        lo = indexof[entry.start]
        hi = indexof[entry.end]
        # discard ranges that ended before the current one begins
        while open_ranges and open_ranges[-1][1] < lo:
            open_ranges.pop()
        if open_ranges:
            outer_lo, outer_hi = open_ranges[-1]
            assert outer_lo <= lo <= hi <= outer_hi
        open_ranges.append((lo, hi))
879
880
def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
    """
    Give every instruction inside an exception table entry's range its own
    copy of that entry. Nested entries are supported.

    Entries are collected from the instructions that already carry one and
    are keyed by their (start index, end index) pair; two entries with the
    same range must compare equal.
    """
    position = get_indexof(instructions)
    by_range: Dict[Tuple[int, int], InstructionExnTabEntry] = {}
    for inst in instructions:
        entry = inst.exn_tab_entry
        if not entry:
            continue
        span = (position[entry.start], position[entry.end])
        if span in by_range:
            assert entry == by_range[span]
        by_range[span] = entry

    # increasing start, then decreasing end: outer entries precede inner ones
    ordered_spans = sorted(by_range, key=lambda span: (span[0], -span[1]))
    ordered = [by_range[span] for span in ordered_spans]
    check_inst_exn_tab_entries_nested(ordered, position)

    # Inner entries are applied after the outer ones that contain them, so
    # the innermost entry is the one each instruction ends up with.
    for entry in ordered:
        first = position[entry.start]
        last = position[entry.end]
        for idx in range(first, last + 1):
            instructions[idx].exn_tab_entry = copy.copy(entry)
906
907
def check_inst_exn_tab_entries_valid(instructions: List[Instruction]):
    """
    Assert that the exn_tab_entry attached to each instruction is valid:

    - exception table entries only occur on Python 3.11+;
    - no two instructions share the same entry object;
    - an entry's start, end, and target are all present in `instructions`;
    - an instruction carrying an entry sits between that entry's start and
      end instructions (inclusive).

    Implicitly checks for no duplicate instructions.
    """
    indexof = get_indexof(instructions)
    seen_entry_ids = set()
    for pos, inst in enumerate(instructions):
        entry = inst.exn_tab_entry
        if not entry:
            continue
        assert sys.version_info >= (3, 11)
        # identity, not equality: each instruction needs its own entry object
        entry_id = id(entry)
        assert entry_id not in seen_entry_ids
        seen_entry_ids.add(entry_id)
        for endpoint in (entry.start, entry.end, entry.target):
            assert endpoint in indexof
        assert indexof[entry.start] <= pos <= indexof[entry.end]
930
931
def strip_extended_args(instructions: List[Instruction]) -> None:
    """Remove every EXTENDED_ARG instruction from `instructions`, in place."""
    extended_arg = dis.EXTENDED_ARG
    instructions[:] = [inst for inst in instructions if inst.opcode != extended_arg]
934
935
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
    """
    Rewrite LOAD_METHOD/CALL_METHOD into LOAD_ATTR/CALL_FUNCTION.

    LOAD_METHOD puts a NULL on the stack which causes issues, so it is
    replaced. Only valid before Python 3.11. The instructions are modified
    in place and the same list is returned.
    """
    assert sys.version_info < (3, 11)
    replacements = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
    for inst in instructions:
        new_name = replacements.get(inst.opname)
        if new_name is not None:
            inst.opname = new_name
            inst.opcode = dis.opmap[new_name]
    return instructions
945
946
def remove_jump_if_none(instructions: List[Instruction]) -> None:
    """
    Expand every jump whose opname contains "_NONE" into the sequence
    LOAD_CONST None; IS_OP; POP_JUMP_*_IF_TRUE, in place.

    The original instruction object is reused as the LOAD_CONST so that
    anything pointing at it still points at the start of the sequence.
    """
    rewritten = []
    for inst in instructions:
        rewritten.append(inst)
        original_name = inst.opname
        if "_NONE" not in original_name:
            continue

        # IS_OP with arg 1 for the "NOT" jump variants, arg 0 otherwise
        is_op = create_instruction("IS_OP", arg=int("NOT" in original_name))
        is_op.argval = is_op.arg
        is_op.positions = inst.positions

        if sys.version_info >= (3, 12):
            jump_name = "POP_JUMP_IF_TRUE"
        elif "FORWARD" in original_name:
            jump_name = "POP_JUMP_FORWARD_IF_TRUE"
        else:
            jump_name = "POP_JUMP_BACKWARD_IF_TRUE"
        jump_op = create_instruction(jump_name, target=inst.target)
        jump_op.positions = inst.positions

        entry = inst.exn_tab_entry
        # the sequence now ends at jump_op, so an entry ending here must too
        if entry and entry.end is inst:
            entry.end = jump_op
        # the new instructions are covered by the same exception table entry
        is_op.exn_tab_entry = copy.copy(entry)
        jump_op.exn_tab_entry = copy.copy(entry)

        # turn the original instruction into LOAD_CONST None
        inst.opcode = dis.opmap["LOAD_CONST"]
        inst.opname = "LOAD_CONST"
        inst.arg = None
        inst.argval = None

        rewritten += [is_op, jump_op]
    instructions[:] = rewritten
978
979
def remove_binary_store_slice(instructions: List[Instruction]) -> None:
    """
    Expand BINARY_SLICE / STORE_SLICE into BUILD_SLICE 2 followed by
    BINARY_SUBSCR / STORE_SUBSCR, in place.

    The original instruction object becomes the BUILD_SLICE so that anything
    pointing at it still points at the start of the pair.
    """
    rewritten = []
    for inst in instructions:
        rewritten.append(inst)
        if inst.opname not in ("BINARY_SLICE", "STORE_SLICE"):
            continue

        subscr_inst = create_instruction(inst.opname.replace("SLICE", "SUBSCR"))
        subscr_inst.positions = inst.positions

        entry = inst.exn_tab_entry
        # the pair now ends at subscr_inst, so an entry ending here must too
        if entry and entry.end is inst:
            entry.end = subscr_inst
        subscr_inst.exn_tab_entry = copy.copy(entry)

        # turn the original instruction into BUILD_SLICE 2
        inst.opcode = dis.opmap["BUILD_SLICE"]
        inst.opname = "BUILD_SLICE"
        inst.arg = 2
        inst.argval = 2

        rewritten.append(subscr_inst)
    instructions[:] = rewritten
998
999
# Maps a fused opname to the pair of plain opnames it is split into,
# in execution order.
FUSED_INSTS = {
    "LOAD_FAST_LOAD_FAST": ("LOAD_FAST", "LOAD_FAST"),
    "STORE_FAST_STORE_FAST": ("STORE_FAST", "STORE_FAST"),
    "STORE_FAST_LOAD_FAST": ("STORE_FAST", "LOAD_FAST"),
}


def remove_fused_load_store(instructions: List[Instruction]) -> None:
    """
    Split every fused instruction listed in FUSED_INSTS into its two plain
    instructions, in place.

    The fused instruction's argval is expected to be a pair; its first
    element goes to the first instruction and its second to the second.
    The original instruction object is reused as the first instruction so
    that anything pointing at it still points at the start of the pair.
    """
    new_insts = []
    for inst in instructions:
        new_insts.append(inst)
        if inst.opname in FUSED_INSTS:
            inst0, inst1 = FUSED_INSTS[inst.opname]
            argval0, argval1 = inst.argval

            # modify inst in-place to preserve jump target
            inst.opcode = dis.opmap[inst0]
            inst.opname = inst0
            inst.argval = argval0

            new_inst = create_instruction(inst1, argval=argval1)
            # carry over source position info, as remove_jump_if_none and
            # remove_binary_store_slice do for the instructions they add
            new_inst.positions = inst.positions
            # update inst.exn_tab_entry.end if necessary
            if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
                inst.exn_tab_entry.end = new_inst
            # preserve exception table entries
            new_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)

            new_insts.append(new_inst)
    instructions[:] = new_insts
1029
1030
def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> None:
    """
    Convert zero-argument ``super()`` calls into the explicit two-argument
    form by loading ``__class__`` and the first local variable before the
    call. `instructions` is updated in place.
    """
    closure_names = (code.co_cellvars or ()) + (code.co_freevars or ())
    if len(code.co_varnames) == 0:
        # A function with no argument cannot contain a valid "super()" call
        return

    # opname of the instruction that directly follows LOAD_GLOBAL super
    # in a zero-argument call, per Python version
    if sys.version_info >= (3, 12):
        call_opname = "CALL"
    elif sys.version_info >= (3, 11):
        call_opname = "PRECALL"
    else:
        call_opname = "CALL_FUNCTION"

    output = []
    for idx, inst in enumerate(instructions):
        output.append(inst)
        if not (inst.opname == "LOAD_GLOBAL" and inst.argval == "super"):
            continue
        call_site = instructions[idx + 1]
        if not (call_site.arg == 0 and call_site.opname == call_opname):
            continue

        assert "__class__" in closure_names
        output.append(create_instruction("LOAD_DEREF", argval="__class__"))
        first_var = code.co_varnames[0]
        # the first variable may itself be a cell/free variable
        load_opname = "LOAD_DEREF" if first_var in closure_names else "LOAD_FAST"
        output.append(create_instruction(load_opname, argval=first_var))

        call_site.arg = 2
        call_site.argval = 2
        if call_site.opname == "PRECALL":
            # also update the following CALL instruction
            following_call = instructions[idx + 2]
            following_call.arg = 2
            following_call.argval = 2

    instructions[:] = output
1067
1068
def fix_extended_args(instructions: List[Instruction]) -> int:
    """
    Fill in correct argvals for EXTENDED_ARG ops.

    Instructions whose arg does not fit in one byte get one, two, or three
    freshly created EXTENDED_ARG prefixes. `instructions` is updated in place.

    Returns:
        The number of instructions added to the list (never negative).
    """
    output: List[Instruction] = []

    def drop_trailing_extended_args(limit):
        # remove at most `limit` EXTENDED_ARG ops from the end of `output`
        removed = 0
        while removed < limit and output and output[-1].opcode == dis.EXTENDED_ARG:
            output.pop()
            removed += 1

    for inst in instructions:
        if inst.opcode == dis.EXTENDED_ARG:
            # Leave this instruction alone for now so we never shrink code
            inst.arg = 0
        elif inst.arg:
            if inst.arg > 0xFFFFFF:
                prefixes = 3
            elif inst.arg > 0xFFFF:
                prefixes = 2
            elif inst.arg > 0xFF:
                prefixes = 1
            else:
                prefixes = 0
            if prefixes:
                drop_trailing_extended_args(prefixes)
                # highest-order prefix first: shifts of 24, 16, 8 as needed
                for shift in range(8 * prefixes, 0, -8):
                    output.append(
                        create_instruction("EXTENDED_ARG", arg=inst.arg >> shift)
                    )
        output.append(inst)

    added = len(output) - len(instructions)
    assert added >= 0
    instructions[:] = output
    return added
1100
1101
def instruction_size(inst) -> int:
    """
    Return the size of `inst` in bytes.

    Before Python 3.11 every instruction is 2 bytes. From 3.11 on, 2 more
    bytes are added per entry that py_opcode_caches reports for the opcode
    (presumably the opcode's inline cache entries; confirm against
    torch._C._dynamo.eval_frame).
    """
    import torch

    if sys.version_info < (3, 11):
        return 2
    cache_entries = torch._C._dynamo.eval_frame.py_opcode_caches[inst.opcode]
    return 2 + 2 * cache_entries
1108
1109
def check_offsets(instructions) -> None:
    """Assert that each instruction's offset equals the summed size of all
    instructions before it."""
    expected = 0
    for inst in instructions:
        assert inst.offset == expected
        expected += instruction_size(inst)
1115
1116
def update_offsets(instructions) -> None:
    """Assign each instruction its byte offset: the summed size of all
    instructions before it."""
    running = 0
    for inst in instructions:
        inst.offset = running
        running += instruction_size(inst)
1122
1123
def debug_bytes(*args) -> str:
    """
    Render byte sequences as aligned rows for a mismatch report.

    The output starts with "bytes mismatch", followed by one row of indices,
    one row per argument, and a final row holding 1 where the last two
    arguments differ and 0 where they agree. Every value is printed as a
    zero-padded 3-digit number.
    """
    width = max(len(seq) for seq in args)
    differs = [int(x != y) for x, y in zip(args[-1], args[-2])]
    rows = [range(width), *args, differs]
    rendered = [" ".join(f"{value:03}" for value in row) for row in rows]
    return "bytes mismatch\n" + "\n".join(rendered)
1133
1134
def debug_checks(code):
    """Make sure our assembler produces same bytes as we start with.

    Runs `code` through transform_code_object with a no-op transformation
    (safe=True) and asserts that co_code and co_lnotab are unchanged.
    """
    roundtrip = transform_code_object(code, lambda _insts, _opts: None, safe=True)
    for attr in ("co_code", "co_lnotab"):
        assert getattr(code, attr) == getattr(roundtrip, attr), debug_bytes(
            getattr(code, attr), getattr(roundtrip, attr)
        )
1140
1141
# Opcode categories from `dis`, stored as sets for fast membership tests.
HAS_LOCAL = {*dis.haslocal}
HAS_NAME = {*dis.hasname}
HAS_FREE = {*dis.hasfree}
HAS_CONST = {*dis.hasconst}
1146
1147
def get_const_index(code_options, val) -> int:
    """
    Return the index of `val` in code_options["co_consts"], appending it
    first if it is not already there.

    Lookup is by identity, not equality: some values compare equal yet
    behave differently, e.g. 0.0 == -0.0 but they have different effects
    in torch.copysign.
    """
    found = next(
        (
            pos
            for pos, existing in enumerate(code_options["co_consts"])
            if existing is val
        ),
        None,
    )
    if found is not None:
        return found
    code_options["co_consts"] += (val,)
    return len(code_options["co_consts"]) - 1
1158
1159
def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None):
    """
    Compute each instruction's `arg` from its `argval` where needed.

    Names missing from code_options["co_names"] and constants missing from
    code_options["co_consts"] are appended to those tuples.

    Args:
        instructions: instructions to update in place.
        code_options: dict of code object fields (co_names, co_varnames,
            co_cellvars, co_freevars, co_consts are read here).
        varname_from_oparg: must be None before Python 3.11 and a callable
            mapping an oparg index to a variable name (raising IndexError
            past the end) on 3.11+.
    """
    name_slots = {name: slot for slot, name in enumerate(code_options["co_names"])}

    def get_name_index(name) -> int:
        if name not in name_slots:
            # Add a missing item to co_names
            name_slots[name] = len(name_slots)
            code_options["co_names"] = (*code_options["co_names"], name)
            assert len(code_options["co_names"]) == len(name_slots)
        return name_slots[name]

    def argval_provided(inst) -> bool:
        # argval is prioritized over arg
        return inst.argval is not _NotProvided

    closure_names = code_options["co_cellvars"] + code_options["co_freevars"]
    if sys.version_info < (3, 11):
        assert varname_from_oparg is None
        varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])}
        freenames = {name: idx for idx, name in enumerate(closure_names)}
    else:
        assert callable(varname_from_oparg)
        # enumerate every oparg index until the callable runs out of names
        allnames = {}
        for idx in itertools.count():
            try:
                name = varname_from_oparg(idx)
            except IndexError:
                break
            allnames[name] = idx
        varnames = {name: allnames[name] for name in code_options["co_varnames"]}
        freenames = {name: allnames[name] for name in closure_names}

    # Opcodes that need both arg and argval (see create_instruction): from
    # the given Python version on, the low bit of arg is a flag that is
    # kept while the name index is stored in the bits above it.
    low_bit_flag_since = {"LOAD_GLOBAL": (3, 11), "LOAD_ATTR": (3, 12)}

    for inst in instructions:
        opname = inst.opname
        if opname in low_bit_flag_since:
            assert inst.argval is not _NotProvided
            if sys.version_info >= low_bit_flag_since[opname]:
                assert inst.arg is not None
                inst.arg = (get_name_index(inst.argval) << 1) + (
                    cast(int, inst.arg) % 2
                )
            else:
                inst.arg = get_name_index(inst.argval)
        elif opname == "LOAD_SUPER_ATTR":
            assert inst.arg is not None
            assert inst.argval is not _NotProvided
            # Copy low bit, force second bit on for explicit super (the "+ 2")
            inst.arg = (
                (get_name_index(inst.argval) << 2) + (cast(int, inst.arg) % 2) + 2
            )
        elif inst.opcode in HAS_LOCAL:
            if argval_provided(inst):
                # instructions like LOAD_FAST used for both local and free vars
                use_free = (
                    sys.version_info >= (3, 13) and inst.argval not in varnames
                )
                table = freenames if use_free else varnames
                inst.arg = table[inst.argval]
        elif inst.opcode in HAS_NAME:
            if argval_provided(inst):
                inst.arg = get_name_index(inst.argval)
        elif inst.opcode in HAS_FREE:
            if argval_provided(inst):
                inst.arg = freenames[inst.argval]
        elif inst.opcode in HAS_CONST:
            # NOTE: only update argval if arg is not provided. This assumes
            # that any additions to co_consts are appended.
            if inst.arg is None:
                # cannot use a dictionary since consts may not be hashable
                const_idx = get_const_index(code_options, inst.argval)
                assert const_idx >= 0
                inst.arg = const_idx
1256
1257
def clear_instruction_args(instructions):
    """
    Reset `arg` to None on instructions that have an argval and whose opcode
    is in HAS_LOCAL, HAS_NAME, HAS_FREE, or HAS_CONST.

    LOAD_GLOBAL, LOAD_ATTR, and LOAD_SUPER_ATTR are left untouched.
    Useful for using dis'd bytecode within generated bytecode.
    """
    keeps_arg = ("LOAD_GLOBAL", "LOAD_ATTR", "LOAD_SUPER_ATTR")
    arg_from_argval_groups = (HAS_LOCAL, HAS_NAME, HAS_FREE, HAS_CONST)
    for inst in instructions:
        if inst.argval is _NotProvided:
            continue
        if inst.opname in keeps_arg:
            continue
        if any(inst.opcode in group for group in arg_from_argval_groups):
            inst.arg = None
1273
1274
def get_code_keys() -> List[str]:
    """
    Return the code object attribute names for the running Python version,
    in the order used to build the positional arguments of types.CodeType.
    """
    # Python 3.11 changes to code keys are not fully documented.
    # See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
    # for new format.
    py = sys.version_info
    keys = [
        "co_argcount",
        "co_posonlyargcount",
        "co_kwonlyargcount",
        "co_nlocals",
        "co_stacksize",
        "co_flags",
        "co_code",
        "co_consts",
        "co_names",
        "co_varnames",
        "co_filename",
        "co_name",
    ]
    if py >= (3, 11):
        keys.append("co_qualname")
    keys.append("co_firstlineno")
    keys.append("co_linetable" if py >= (3, 10) else "co_lnotab")
    if py >= (3, 11):
        # not documented, but introduced in https://github.com/python/cpython/issues/84403
        keys.append("co_exceptiontable")
    keys += ["co_freevars", "co_cellvars"]
    return keys
1312
1313
def transform_code_object(code, transformations, safe=False) -> types.CodeType:
    """
    Apply `transformations` to the instructions of `code` and reassemble.

    Args:
        code: the code object to transform.
        transformations: callable invoked as
            transformations(instructions, code_options); it may modify both
            in place.
        safe: forwarded to cleaned_instructions.

    Returns:
        The newly assembled code object.
    """
    code_keys = get_code_keys()
    options = {key: getattr(code, key) for key in code_keys}
    assert len(options["co_varnames"]) == options["co_nlocals"]

    insts = cleaned_instructions(code, safe)
    propagate_line_nums(insts)

    transformations(insts, options)
    _, new_code = clean_and_assemble_instructions(insts, code_keys, options)
    return new_code
1324
1325
def clean_and_assemble_instructions(
    instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any]
) -> Tuple[List[Instruction], types.CodeType]:
    """
    Finalize `instructions` and assemble them into a new code object.

    `code_options` is updated in place with the recomputed co_nlocals,
    co_code, co_stacksize, line table and (on 3.11+) co_exceptiontable.

    Returns:
        A tuple of the (in-place updated) instructions and the new code
        object, built from code_options in the order given by `keys`.
    """
    # also implicitly checks for no duplicate instructions
    check_inst_exn_tab_entries_valid(instructions)

    code_options["co_nlocals"] = len(code_options["co_varnames"])
    if sys.version_info >= (3, 11):
        # temporary code object with updated names
        tmp_code = types.CodeType(*[code_options[k] for k in keys])
        varname_from_oparg = tmp_code._varname_from_oparg  # type: ignore[attr-defined]
    else:
        varname_from_oparg = None
    fix_vars(instructions, code_options, varname_from_oparg=varname_from_oparg)

    # Adding EXTENDED_ARG ops might change offsets; repeat until a pass
    # adds none.
    while True:
        update_offsets(instructions)
        devirtualize_jumps(instructions)
        if not fix_extended_args(instructions):
            break

    remove_extra_line_nums(instructions)
    bytecode, lnotab = assemble(instructions, code_options["co_firstlineno"])
    linetable_key = "co_lnotab" if sys.version_info < (3, 10) else "co_linetable"
    code_options[linetable_key] = lnotab

    code_options["co_code"] = bytecode
    code_options["co_stacksize"] = stacksize_analysis(instructions)
    ignored = {"co_posonlyargcount"}
    assert set(keys) - ignored == set(code_options.keys()) - ignored
    if sys.version_info >= (3, 11):
        exn_table = compute_exception_table(instructions)
        code_options["co_exceptiontable"] = assemble_exception_table(exn_table)

    return instructions, types.CodeType(*[code_options[k] for k in keys])
1365
1366
def populate_kw_names_argval(instructions, consts):
    """Set the argval of every KW_NAMES instruction to the entry of `consts`
    selected by the instruction's arg."""
    kw_names_insts = (inst for inst in instructions if inst.opname == "KW_NAMES")
    for inst in kw_names_insts:
        inst.argval = consts[inst.arg]
1371
1372
def cleaned_instructions(code, safe=False) -> List[Instruction]:
    """
    Disassemble `code` into a list of mutable Instructions and normalize it.

    Jumps and (on 3.11+) the exception table are virtualized and
    EXTENDED_ARG ops are stripped. Unless `safe` is set, the
    version-specific rewrites remove_load_call_method (< 3.11) and
    explicit_super (< 3.12) are applied. On 3.11+ the *_NONE jumps, slice
    ops (3.12+) and fused load/store ops (3.13+) are expanded regardless of
    `safe`, followed by an offset update and jump devirtualization.
    """
    py = sys.version_info
    instructions = [convert_instruction(raw) for raw in dis.get_instructions(code)]
    check_offsets(instructions)
    if py >= (3, 11):
        populate_kw_names_argval(instructions, code.co_consts)
        virtualize_exception_table(code.co_exceptiontable, instructions)
    virtualize_jumps(instructions)
    strip_extended_args(instructions)

    if not safe:
        if py < (3, 11):
            remove_load_call_method(instructions)
        if py < (3, 12):
            explicit_super(code, instructions)

    if py < (3, 11):
        return instructions

    remove_jump_if_none(instructions)
    if py >= (3, 12):
        remove_binary_store_slice(instructions)
    if py >= (3, 13):
        remove_fused_load_store(instructions)
    update_offsets(instructions)
    devirtualize_jumps(instructions)
    return instructions
1395
1396
# Module-wide counter that supplies the numeric suffix for unique_id.
_unique_id_counter = itertools.count()


def unique_id(name) -> str:
    """Return `name` followed by an underscore and the next value of the
    module-wide counter."""
    suffix = next(_unique_id_counter)
    return f"{name}_{suffix}"
1402
1403
def is_generator(code: types.CodeType) -> bool:
    """Return True if the 0x20 bit (the generator flag) is set in
    code.co_flags."""
    generator_flag = 0x20
    return bool(code.co_flags & generator_flag)
1407
1408
def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
    """Generates bytecode from a template function `fn` for use in
    dynamo bytecode generation.

    This makes it possible to write Python-version-independent bytecode as
    ordinary Python, e.g. a loop copying one dictionary into another:

    def template(d1, d2):
        for k, v in d1.items():
            d2[k] = v


    or a try block:

    def template():
        try:
            dummy1
        except:
            dummy2
            raise
        dummy3

    Args:
        fn: a function template to generate bytecode from
        varname_map: a mapping of `fn`'s varnames to new names. Any
            instruction whose argval is a key of this map gets the mapped
            value as its new argval. For example, local variables in `fn`
            can be replaced with new names that are generated by
            `OutputGraph.new_var`.
        noreturn: remove all RETURN_* bytecodes and replace them with a jump
            to the end of the bytecode.
        noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive).

    Returns:
        The list of generated instructions.
    """
    insts = cleaned_instructions(fn.__code__)
    clear_instruction_args(insts)

    if noprefix:
        resume_at = next(
            (pos for pos, inst in enumerate(insts) if inst.opname == "RESUME"),
            None,
        )
        if resume_at is not None:
            insts = insts[resume_at + 1 :]

    for inst in insts:
        # Without this reset, the generated bytecode's line numbers
        # would be based on fn's.
        inst.starts_line = None
        if varname_map and inst.argval in varname_map:
            inst.argval = varname_map[inst.argval]

    if not noreturn:
        return insts

    if sys.version_info >= (3, 12):
        # replace RETURN_CONST with LOAD_CONST RETURN_VALUE
        expanded = []
        for inst in insts:
            expanded.append(inst)
            if inst.opname == "RETURN_CONST":
                inst.opcode = dis.opmap["LOAD_CONST"]
                inst.opname = "LOAD_CONST"
                # no need to propagate target/exn table
                expanded.append(create_instruction("RETURN_VALUE"))
        insts = expanded

    returns = [inst for inst in insts if inst.opname == "RETURN_VALUE"]
    if not returns:
        return insts

    tail = insts[-1]
    if len(returns) == 1 and returns[0] is tail:
        # only 1 return at the end - just pop it
        insts.pop()
        return insts

    # Create the jump target: a trailing return is turned into a NOP and
    # reused, otherwise a fresh NOP is appended.
    if tail is returns[-1]:
        tail.opname = "NOP"
        tail.opcode = dis.opmap["NOP"]
        tail.arg = None
        tail.argval = _NotProvided
        returns.pop()
    else:
        insts.append(create_instruction("NOP"))
    landing = insts[-1]

    # replace the remaining returns with jumps to the landing NOP
    for inst in returns:
        # Overwrite the fields in place instead of swapping in a new
        # instruction, since jump targets and exception table entries
        # may reference the existing object.
        jump_inst = create_jump_absolute(landing)
        inst.opname = jump_inst.opname
        inst.opcode = jump_inst.opcode
        inst.arg = jump_inst.arg
        inst.argval = jump_inst.argval
        inst.target = jump_inst.target

    return insts
1504