xref: /aosp_15_r20/external/pytorch/torch/_dynamo/bytecode_analysis.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import bisect
3import dataclasses
4import dis
5import sys
6from typing import Any, Set, Union
7
8
# Opcodes after which the analyses in this module stop following the
# fall-through edge to the next instruction (returns, raises and
# unconditional jumps).
TERMINAL_OPCODES = {
    dis.opmap["RETURN_VALUE"],
    dis.opmap["JUMP_FORWARD"],
    dis.opmap["RAISE_VARARGS"],
    # TODO(jansel): double check exception handling
}
if sys.version_info >= (3, 9):
    TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
if sys.version_info >= (3, 11):
    # JUMP_FORWARD is already part of the initial set above.
    TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"])
else:
    TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
# RETURN_CONST was introduced in 3.12 but is not present in the opcode table
# of every later interpreter, so guard on its presence instead of letting the
# module fail at import time with a KeyError.
if sys.version_info >= (3, 12) and "RETURN_CONST" in dis.opmap:
    TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])

# Every opcode that `dis` reports as taking a relative or absolute jump target.
JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
# Opcodes that `dis` reports as accessing a local / free variable.
HASLOCAL = set(dis.haslocal)
HASFREE = set(dis.hasfree)

stack_effect = dis.stack_effect
30
31
def get_indexof(insts):
    """
    Build a lookup table from each instruction to its position in ``insts``.

    The instruction objects themselves are used as dictionary keys, so they
    must be hashable.  An assertion fails if the same key is encountered
    more than once.
    """
    positions = {}
    position = 0
    for instruction in insts:
        assert instruction not in positions
        positions[instruction] = position
        position += 1
    return positions
42
43
def remove_dead_code(instructions):
    """Dead code elimination.

    Starting from the first instruction, marks as live every instruction
    reachable by falling through, by following a jump target, or by entering
    the handler named in an instruction's ``exn_tab_entry``.  Fall-through
    stops after any opcode in ``TERMINAL_OPCODES``.

    On Python 3.11+ the ``start``/``end`` of the exception table entries of
    live instructions are rewritten in place so that they point at live
    instructions.

    Returns a new list holding only the live instructions, in their original
    order.
    """
    indexof = get_indexof(instructions)
    live_code = set()

    # An explicit worklist is used instead of recursion so that long chains of
    # jumps / handlers cannot exhaust the interpreter's recursion limit.  The
    # live set is a reachability closure, so the visiting order is irrelevant.
    num_insts = len(instructions)
    pending = [0]
    while pending:
        i = pending.pop()
        while i < num_insts and i not in live_code:
            live_code.add(i)
            inst = instructions[i]
            if inst.exn_tab_entry:
                pending.append(indexof[inst.exn_tab_entry.target])
            if inst.opcode in JUMP_OPCODES:
                pending.append(indexof[inst.target])
            if inst.opcode in TERMINAL_OPCODES:
                # no fall-through edge out of this instruction
                break
            i += 1

    # change exception table entries if start/end instructions are dead
    # assumes that exception table entries have been propagated,
    # e.g. with bytecode_transformation.propagate_inst_exn_table_entries,
    # and that every instruction with an exn_tab_entry lies within its
    # start/end.
    if sys.version_info >= (3, 11):
        live_idx = sorted(live_code)
        for i, inst in enumerate(instructions):
            if i in live_code and inst.exn_tab_entry:
                # find leftmost live instruction >= start
                start_idx = bisect.bisect_left(
                    live_idx, indexof[inst.exn_tab_entry.start]
                )
                assert start_idx < len(live_idx)
                # find rightmost live instruction <= end
                end_idx = (
                    bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1
                )
                assert end_idx >= 0
                assert live_idx[start_idx] <= i <= live_idx[end_idx]
                inst.exn_tab_entry.start = instructions[live_idx[start_idx]]
                inst.exn_tab_entry.end = instructions[live_idx[end_idx]]

    return [inst for i, inst in enumerate(instructions) if i in live_code]
87
88
def remove_pointless_jumps(instructions):
    """Eliminate jumps to the next instruction.

    Only instructions whose ``opname`` is ``JUMP_ABSOLUTE`` and whose
    ``target`` is the instruction immediately following them are dropped.
    Returns a new list; the input list is not modified.
    """
    # Collect identities of the jumps to drop by scanning adjacent pairs.
    skip_ids = set()
    successors = instructions[1:]
    for current, following in zip(instructions, successors):
        if current.opname == "JUMP_ABSOLUTE" and current.target is following:
            skip_ids.add(id(current))

    kept = []
    for inst in instructions:
        if id(inst) not in skip_ids:
            kept.append(inst)
    return kept
97
98
def propagate_line_nums(instructions):
    """Ensure every instruction has line number set in case some are removed.

    Walks the instructions in order and overwrites each ``starts_line`` with
    the most recent truthy ``starts_line`` seen so far (``None`` until the
    first one is found).  Mutates the instructions in place; returns nothing.
    """
    latest_line = None
    for inst in instructions:
        # A truthy starts_line begins a new line; otherwise inherit the last.
        latest_line = inst.starts_line or latest_line
        inst.starts_line = latest_line
112
113
def remove_extra_line_nums(instructions):
    """Remove extra starts line properties before packing bytecode.

    Walks the instructions in order and clears ``starts_line`` (sets it to
    ``None``) on every instruction whose line number equals the last line
    number kept.  Mutates the instructions in place; returns nothing.
    """
    last_kept_line = None
    for inst in instructions:
        line = inst.starts_line
        if line is None:
            continue
        if line == last_kept_line:
            # Same line as the last kept marker: redundant.
            inst.starts_line = None
        else:
            last_kept_line = line
130
131
@dataclasses.dataclass
class ReadsWrites:
    """Accumulator used by ``livevars_analysis`` for one class of paths."""

    # ``argval`` of LOAD/DELETE instructions recorded as reads
    reads: Set[Any]
    # ``argval`` of STORE instructions recorded as writes
    writes: Set[Any]
    # instruction indices at which a walk has already been started
    visited: Set[Any]
137
138
def livevars_analysis(instructions, instruction):
    """Collect the variable names read at or after ``instruction``.

    Walks forward from ``instruction``.  For opcodes in ``HASLOCAL`` or
    ``HASFREE``, a LOAD/DELETE records its ``argval`` as a read unless that
    name is already in the ``must`` writes, and a STORE records a write.
    Jump targets and exception handlers are explored recursively with the
    ``may`` state, and everything after the first jump on a path is also
    attributed to ``may``.

    Returns the union of the reads gathered in both states.  Raises
    ``NotImplementedError`` for a local/free-variable opcode that is not a
    LOAD, DELETE, STORE or MAKE_CELL.
    """
    indexof = get_indexof(instructions)
    must = ReadsWrites(set(), set(), set())
    may = ReadsWrites(set(), set(), set())
    num_insts = len(instructions)

    def record_access(state, inst):
        # Classify a local/free-variable instruction by its opname.
        opname = inst.opname
        if "LOAD" in opname or "DELETE" in opname:
            # A name definitely written earlier is not a live-in read.
            if inst.argval not in must.writes:
                state.reads.add(inst.argval)
        elif "STORE" in opname:
            state.writes.add(inst.argval)
        elif opname != "MAKE_CELL":
            raise NotImplementedError(f"unhandled {inst.opname}")

    def walk(state, start):
        if start in state.visited:
            return
        state.visited.add(start)

        pos = start
        while pos < num_insts:
            inst = instructions[pos]
            pos += 1
            if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
                record_access(state, inst)
            if inst.exn_tab_entry:
                walk(may, indexof[inst.exn_tab_entry.target])
            if inst.opcode in JUMP_OPCODES:
                walk(may, indexof[inst.target])
                # past a jump, the remainder of this path is only "may"
                state = may
            if inst.opcode in TERMINAL_OPCODES:
                return

    walk(must, indexof[instruction])
    return must.reads | may.reads
171
172
@dataclasses.dataclass
class FixedPointBox:
    """Mutable boolean flag shared between ``StackSize`` instances."""

    # Starts out True; ``StackSize`` methods set it to False whenever they
    # change a bound.
    value: bool = True
176
177
@dataclasses.dataclass
class StackSize:
    """Inclusive ``[low, high]`` range of stack depths seen at one instruction.

    Any method that changes a bound clears the shared ``fixed_point`` flag.
    """

    low: Union[int, float]
    high: Union[int, float]
    fixed_point: FixedPointBox

    def zero(self):
        """Reset both bounds to 0 and clear the fixed-point flag."""
        self.low = self.high = 0
        self.fixed_point.value = False

    def offset_of(self, other, n):
        """Widen this range to include ``other``'s range shifted by ``n``."""
        self._widen(other.low + n, other.high + n)

    def exn_tab_jump(self, depth):
        """Widen this range to include the single depth ``depth``."""
        self._widen(depth, depth)

    def _widen(self, low, high):
        # Shared merge step: extend the bounds and record whether they moved.
        before = (self.low, self.high)
        self.low = min(self.low, low)
        self.high = max(self.high, high)
        if before != (self.low, self.high):
            self.fixed_point.value = False
202
203
def stacksize_analysis(instructions) -> Union[int, float]:
    """Compute the largest stack depth recorded for any instruction.

    Every instruction gets a ``StackSize`` range, initially empty
    (``low=inf``, ``high=-inf``), with the first instruction pinned to depth
    0.  Ranges are then propagated along fall-through edges, jump edges and
    exception-table edges, for at most 100 sweeps or until a sweep changes
    nothing.

    Args:
        instructions: non-empty list of instructions.

    Returns:
        The maximum ``high`` bound over all instructions.

    Raises:
        AssertionError: if ``instructions`` is empty, no fixed point is
            reached, or some recorded depth is negative.
    """
    assert instructions
    fixed_point = FixedPointBox()
    stack_sizes = {
        inst: StackSize(float("inf"), float("-inf"), fixed_point)
        for inst in instructions
    }
    stack_sizes[instructions[0]].zero()

    # Loop invariants, built once rather than on every sweep: each
    # instruction paired with its successor (None after the last one) ...
    inst_pairs = list(zip(instructions, instructions[1:] + [None]))
    # ... and the CALL_FINALLY opcode, which is only looked up on Python 3.8.
    call_finally_opcode = (
        dis.opmap["CALL_FINALLY"] if sys.version_info < (3, 9) else None
    )

    for _ in range(100):
        if fixed_point.value:
            break
        fixed_point.value = True

        for inst, next_inst in inst_pairs:
            stack_size = stack_sizes[inst]
            # CALL_FINALLY in Python 3.8 is handled differently when determining stack depth.
            # See https://github.com/python/cpython/blob/3.8/Python/compile.c#L5450.
            # Essentially, the stack effect of CALL_FINALLY is computed with jump=True,
            # but the resulting stack depth is propagated to the next instruction, not the
            # jump target.
            is_call_finally = (
                call_finally_opcode is not None and inst.opcode == call_finally_opcode
            )
            if inst.opcode not in TERMINAL_OPCODES:
                assert next_inst is not None, f"missing next inst: {inst}"
                # total stack effect of CALL_FINALLY and END_FINALLY in 3.8 is 0
                eff = (
                    0
                    if is_call_finally
                    else stack_effect(inst.opcode, inst.arg, jump=False)
                )
                stack_sizes[next_inst].offset_of(stack_size, eff)
            if inst.opcode in JUMP_OPCODES and not is_call_finally:
                stack_sizes[inst.target].offset_of(
                    stack_size, stack_effect(inst.opcode, inst.arg, jump=True)
                )
            if inst.exn_tab_entry:
                # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
                # on why depth is computed this way.
                depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1
                stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth)

    # Instructions that were never reached keep their (inf, -inf) range and
    # therefore do not influence either extreme.
    low = min(x.low for x in stack_sizes.values())
    high = max(x.high for x in stack_sizes.values())

    assert fixed_point.value, "failed to reach fixed point"
    assert low >= 0
    return high
258