# mypy: allow-untyped-defs
import collections
import dataclasses
import re
import sys
import types
from typing import Counter, Dict, List, Optional

import torch.nn

from . import utils
from .bytecode_transformation import (
    add_push_null,
    add_push_null_call_function_ex,
    create_call_function,
    create_call_method,
    create_dup_top,
    create_instruction,
    create_load_method,
    create_rot_n,
    Instruction,
)
from .exc import unimplemented
from .source import AttrSource, Source
from .utils import is_safe_constant, rot_n_helper
from .variables.base import VariableTracker
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    NumpyNdarrayVariable,
    SymNodeVariable,
    TensorVariable,
    UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable


@dataclasses.dataclass
class GraphOutputEntry:
    """Pairs a graph output value with its position in the output tuple."""

    index: int
    variable: VariableTracker


class PyCodegen:
    """
    Helper class used for constructing Python bytecode
    """

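    # Rough usage sketch (illustrative, not from the original file):
    #   cg = PyCodegen(tx, root, graph_output_var, tempvars={})
    #   cg(some_variable_tracker)  # emit bytecode pushing its value
    #   cg.store("result")         # STORE_FAST result
    #   insts = cg.get_instructions()
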
    def __init__(
        self,
        tx=None,
        root: Optional[torch.nn.Module] = None,
        graph_output_var: Optional[str] = None,
        tempvars=None,
    ) -> None:
        self.root = root
        self.top_of_stack: Optional[VariableTracker] = None
        self.uses: Counter[VariableTracker] = collections.Counter()
        self.graph_outputs: Dict[int, GraphOutputEntry] = {}
        self._output: List[Instruction] = []
        self.tempvars = tempvars or {}
        self.tx = tx
        self.graph_output_var = graph_output_var
        self.code_options = self.tx.output.code_options
        self.cell_and_freevars = self.tx.cell_and_freevars
        self.new_var = self.tx.output.new_var
        self.mutable_side_effects_from_source = False
        self.value_from_source: bool = True

    def restore_stack(self, stack_values, *, value_from_source=True):
        # Re-materialize the given stack values, temporarily codegen-ing
        # mutated values as references to their original sources (see the
        # mutable_side_effects_from_source handling in __call__).
        prior = self.mutable_side_effects_from_source
        self.mutable_side_effects_from_source = True
        prev = self.value_from_source
        self.value_from_source &= value_from_source
        try:
            self.foreach(stack_values)
        finally:
            self.mutable_side_effects_from_source = prior
            self.value_from_source = prev

    def graph_output_vars(self):
        return [x.variable for x in self.graph_outputs.values()]

    def call_reconstruct(self, value):
        # reconstruct() emits instructions via this codegen object and must
        # not return a value.
        res = value.reconstruct(self)
        assert res is None, f"reconstruct!=None {value}"

    def add_push_null(self, gen_fn, call_function_ex=False):
        """
        `gen_fn` generates instructions via PyCodegen methods
        that push a single callable to the stack.

        `add_push_null` pushes a NULL to the stack before or after the
        instructions generated by `gen_fn`, depending on Python version.

        Attempts to use the NULL-push bit on instructions that support one
        (LOAD_GLOBAL in 3.11+, LOAD_ATTR in 3.12+, LOAD_SUPER_ATTR).
        """
        old_len = len(self._output)
        if sys.version_info < (3, 13):
            # gen_fn may emit DUP_TOP instead if TOS is not cleared, which
            # causes problems since NULL is pushed right before the
            # generated instructions in <= 3.12.
            self.clear_tos()
        gen_fn()
        # in-place modify self._output
        added_insts = self._output[old_len:]
        del self._output[old_len:]
        if call_function_ex:
            self._output.extend(add_push_null_call_function_ex(added_insts))
        else:
            self._output.extend(add_push_null(added_insts))
        if sys.version_info >= (3, 13):
            # NULL will be at top of stack
            self.clear_tos()
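
    # Illustrative usage sketch (not part of the original source): to emit a
    # call `fn(x)` where `fn` is a global, one might write:
    #
    #   self.add_push_null(
    #       lambda: self.append_output(self.create_load_global("fn", add=True))
    #   )
    #   self(x_vt)  # x_vt: a VariableTracker holding the argument
    #   self.extend_output(create_call_function(1, False))
    #
    # push_null=False in create_call_function because add_push_null has
    # already arranged the NULL.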

    def __call__(self, value, allow_cache=True):
        """Generate code such that top-of-stack (TOS) is set to value"""
        if isinstance(value, Source):
            self.call_reconstruct(value)
            self.clear_tos()
            return

        assert isinstance(value, VariableTracker)
        output = self._output
        graph_outputs = self.graph_outputs

        if self.top_of_stack is value and allow_cache:
            output.append(create_dup_top())
            return

        if self.mutable_side_effects_from_source:
            # this is needed to get aliasing relationships right
            # value.mutable_local.source will get mutated to hold `value`
            # mutable_side_effects_from_source=False is used to codegen the mutation
            # mutable_side_effects_from_source=True is used to codegen a reference
            from .side_effects import MutableSideEffects

            if isinstance(value.mutable_local, MutableSideEffects):
                self(value.mutable_local.source)
                return

        if allow_cache:
            if value.mutable_local and value.mutable_local in self.tempvars:
                output.append(self.create_load(self.tempvars[value.mutable_local]))
                self.top_of_stack = value
                return
            if self.tempvars.get(value) is not None:
                output.append(self.create_load(self.tempvars[value]))
                self.top_of_stack = value
                return

        if value.source is not None and allow_cache and self.value_from_source:
            self.call_reconstruct(value.source)
        elif value.is_python_constant() and is_safe_constant(
            value.as_python_constant()
        ):
            output.append(self.create_load_const(value.as_python_constant()))
        elif isinstance(value, TensorWithTFOverrideVariable):
            graph_outputs_key = self.add_graph_output(value)

            self.add_push_null(
                lambda: self.load_import_from(utils.__name__, "to_subclass")
            )
            self.load_graph_output(graph_outputs[graph_outputs_key].index)
            output.append(
                self.create_load_global(
                    value.global_mangled_class_name(self.tx), add=True
                )
            )
            output.extend(create_call_function(2, False))
        elif (
            isinstance(value, SymNodeVariable)
            and value.python_type() == float
            and not self.tx.export
        ):
            # This is a little unusual; force the output convention to be a
            # Tensor here.  Don't do this for export because this is
            # apparently load bearing for export tests (but I am a bit
            # doubtful it actually works in the real world)
            # NB: It works to add_graph_output on a computed expression
            # as_tensor here, because we memoize as_tensor calls on
            # SymNodeVariable!
            graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx))

            def gen_fn():
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.append(self.create_load_attr("item"))

            self.add_push_null(gen_fn)
            output.extend(create_call_function(0, False))
        elif isinstance(
            value,
            (
                TensorVariable,
                SymNodeVariable,
                UnspecializedPythonVariable,
                NumpyNdarrayVariable,
            ),
        ):
            graph_outputs_key = self.add_graph_output(value)

            if isinstance(value, NumpyNdarrayVariable):
                self.add_push_null(
                    lambda: self.load_import_from(utils.__name__, "to_numpy_helper")
                )
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
                output.extend(create_call_function(1, False))
            elif isinstance(value, UnspecializedPythonVariable) and value.need_unwrap:

                def gen_fn():
                    self.load_graph_output(graph_outputs[graph_outputs_key].index)
                    output.append(self.create_load_attr("item"))

                self.add_push_null(gen_fn)
                output.extend(create_call_function(0, False))
            else:
                self.load_graph_output(graph_outputs[graph_outputs_key].index)
        elif isinstance(value, NNModuleVariable):
            parts = value.module_key.split(".")
            if parts[0] in self.code_options["co_varnames"]:
                output.append(self.create_load(parts[0]))
                parts = parts[1:]
            else:
                assert self.root is not None
                output.append(self.create_load_output(self.root))
            for part in parts:
                output.append(self.create_load_attr(part))
        else:
            self.uses[value] += 1
            try:
                self.call_reconstruct(value)
            except NotImplementedError:
                unimplemented(f"reconstruct: {value}")
            if allow_cache and value in self.tempvars:
                self._output.append(create_dup_top())
                self.add_cache(value)

        self.top_of_stack = value
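
    # Note (editorial): repeated calls `self(v); self(v)` with allow_cache=True
    # emit DUP_TOP for the second call instead of regenerating the load, via
    # the top_of_stack tracking above.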

    def add_graph_output(self, value):
        # Key on the proxy's id so each distinct graph value is added to the
        # graph outputs only once.
        graph_outputs_key = id(value.as_proxy())
        if graph_outputs_key not in self.graph_outputs:
            self.graph_outputs[graph_outputs_key] = GraphOutputEntry(
                len(self.graph_outputs), value
            )
        return graph_outputs_key

    def load_graph_output(self, index):
        # Emits bytecode equivalent to `graph_output_var[index]`.
        output = self._output
        output.append(self.create_load(self.graph_output_var))
        output.append(self._create_load_const(index))
        output.append(create_instruction("BINARY_SUBSCR"))

    def add_cache(self, value):
        var = self.new_var()
        self.tempvars[value] = var
        if value.mutable_local:
            self.tempvars[value.mutable_local] = var
        self._output.append(self.create_store(var))

    def foreach(self, items):
        for i in items:
            self(i)

    def setup_globally_cached(self, name, value):
        """Store value in a new global"""
        name = re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        f_globals = self.tx.f_globals
        if name in f_globals:
            assert f_globals[name] is value
        else:
            f_globals[name] = value
        return [self.create_load_global(name, add=True)]
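
    # Illustrative: setup_globally_cached("my helper!", fn) sanitizes the name
    # to "my_helper_", installs `fn` into the frame's globals if not already
    # present, and returns the LOAD_GLOBAL instruction that fetches it.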

    def clear_tos(self):
        self.top_of_stack = None

    def append_output(self, inst):
        assert isinstance(inst, Instruction)
        self._output.append(inst)
        self.clear_tos()

    def extend_output(self, insts):
        assert all(isinstance(x, Instruction) for x in insts)
        self._output.extend(insts)
        self.clear_tos()

    def get_instructions(self) -> List[Instruction]:
        return self._output

    def create_load(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("LOAD_DEREF", argval=name)
        assert name in self.code_options["co_varnames"], f"{name} missing"
        return create_instruction("LOAD_FAST", argval=name)

    def create_load_closure(self, name) -> Instruction:
        assert name in self.cell_and_freevars()
        inst_name = "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
        return create_instruction(inst_name, argval=name)

    def create_store(self, name) -> Instruction:
        if name in self.cell_and_freevars():
            return create_instruction("STORE_DEREF", argval=name)
        assert name in self.code_options["co_varnames"]
        return create_instruction("STORE_FAST", argval=name)

    def create_load_global(self, name, add=False) -> Instruction:
        if add:
            self.tx.output.update_co_names(name)
        assert name in self.code_options["co_names"], f"{name} not in co_names"
        return create_instruction("LOAD_GLOBAL", argval=name)

    def create_load_const(self, value) -> Instruction:
        assert is_safe_constant(value), f"unsafe constant {value}"
        return self._create_load_const(value)

    def _create_load_const(self, value) -> Instruction:
        return create_instruction("LOAD_CONST", argval=value)

    create_load_output = _create_load_const

    def load_method(self, name):
        self.tx.output.update_co_names(name)
        self.append_output(create_load_method(name))

    def call_method(self, nargs):
        self.extend_output(create_call_method(nargs))

    def create_load_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("LOAD_ATTR", argval=name)

    def load_attr(self, name):
        self.append_output(self.create_load_attr(name))

    def create_load_attrs(self, names):
        return [self.create_load_attr(name) for name in names.split(".")]

    def create_store_attr(self, name) -> Instruction:
        if name not in self.code_options["co_names"]:
            self.code_options["co_names"] += (name,)
        return create_instruction("STORE_ATTR", argval=name)

    def store_attr(self, name):
        self.append_output(self.create_store_attr(name))

    def load_function_name(self, fn_name, push_null, num_on_stack=0):
        """Load the global `fn_name`, placing it `num_on_stack` entries below
        the current top of the stack."""
        output = []
        if push_null and sys.version_info >= (3, 11):
            output.extend(add_push_null(self.create_load_global(fn_name, add=True)))
            if num_on_stack > 0:
                # rotate twice so that both NULL and the callable end up
                # below the existing num_on_stack entries
                output.extend(
                    [
                        *self.rot_n(num_on_stack + 2),
                        *self.rot_n(num_on_stack + 2),
                    ]
                )
        else:
            output.extend(
                [
                    self.create_load_global(fn_name, add=True),
                    *self.rot_n(num_on_stack + 1),
                ]
            )
        return output
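
    # Illustrative stack effect (push_null=True on 3.11+, num_on_stack=2),
    # stacks written bottom-to-top: [a, b] becomes [NULL, fn, a, b] (with
    # NULL and fn swapped on 3.13, which changed the calling convention);
    # the two rot_n calls move both new entries below the existing ones.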

    def rot_n(self, n):
        try:
            return create_rot_n(n)
        except AttributeError:
            # desired rotate bytecode doesn't exist, generate equivalent bytecode
            return [
                create_instruction("BUILD_TUPLE", arg=n),
                self._create_load_const(rot_n_helper(n)),
                *create_rot_n(2),
                create_instruction("CALL_FUNCTION_EX", arg=0),
                create_instruction("UNPACK_SEQUENCE", arg=n),
            ]
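
    # Illustrative: the fallback above has the same net stack effect as a
    # native rotation. For n=3, [a, b, c] (c on top) becomes [c, a, b]
    # (b on top): the top three values are packed into a tuple, reordered by
    # rot_n_helper(3), and unpacked back onto the stack.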

    def pop_null(self):
        # POP_TOP doesn't work for null, so we pop nulls by pushing in a
        # nop function, calling it (which consumes the null), and popping the result.
        assert sys.version_info >= (3, 11)
        return [
            self._create_load_const(lambda: None),
            # 3.13 swapped NULL and callable
            *(
                (create_instruction("SWAP", arg=2),)
                if sys.version_info >= (3, 13)
                else ()
            ),
            *create_call_function(0, False),
            create_instruction("POP_TOP"),
        ]

    def pop_top(self):
        self.append_output(create_instruction("POP_TOP"))

    def call_function(self, nargs: int, push_null: bool):
        self.extend_output(create_call_function(nargs, push_null=push_null))

    def dup_top(self):
        self.append_output(create_dup_top())

    def store(self, varname):
        self.append_output(self.create_store(varname))

    def make_function_with_closure(
        self, fn_name: str, code: types.CodeType, push_null: bool, num_on_stack=0
    ):
        freevars = code.co_freevars
        assert freevars
        output = self._output

        def gen_fn():
            for var in freevars:
                assert var in self.cell_and_freevars()
                inst_name = (
                    "LOAD_FAST" if sys.version_info >= (3, 13) else "LOAD_CLOSURE"
                )
                output.append(create_instruction(inst_name, argval=var))
            output.append(create_instruction("BUILD_TUPLE", arg=len(freevars)))
            output.append(self.create_load_const(code))
            if sys.version_info < (3, 11):
                # Python < 3.11 expects the qualified name on the stack
                output.append(self.create_load_const(fn_name))
            if sys.version_info >= (3, 13):
                # 3.13+: MAKE_FUNCTION takes no flags; the closure tuple is
                # attached via SET_FUNCTION_ATTRIBUTE (0x08 = closure)
                output.extend(
                    [
                        create_instruction("MAKE_FUNCTION"),
                        create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
                    ]
                )
            else:
                # 0x08 flag: the function has a closure tuple
                output.append(create_instruction("MAKE_FUNCTION", arg=0x08))

        if push_null and sys.version_info >= (3, 11):
            self.add_push_null(gen_fn)
            output.extend(self.rot_n(num_on_stack + 2))
            output.extend(self.rot_n(num_on_stack + 2))
        else:
            gen_fn()
            output.extend(self.rot_n(num_on_stack + 1))
        self.clear_tos()
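
    # Illustrative: for code with freevars ("x", "y") this emits roughly
    #   LOAD_CLOSURE x; LOAD_CLOSURE y; BUILD_TUPLE 2; LOAD_CONST code;
    #   MAKE_FUNCTION 0x08
    # (with a qualname const before 3.11; LOAD_FAST instead of LOAD_CLOSURE
    # and SET_FUNCTION_ATTRIBUTE on 3.13+), rotated below num_on_stack items.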

    def create_load_python_module(self, mod) -> Instruction:
        """
        Generate a LOAD_GLOBAL instruction to fetch a given python module.
        """
        output = self.tx.output
        global_scope = output.global_scope
        name = re.sub(r"^.*[.]", "", mod.__name__)
        if global_scope.get(name, None) is mod:
            return self.create_load_global(name, add=True)
        prefix = f"___module_{name}"
        global_name = self.tx.output.install_global_by_id(prefix, mod)
        return self.create_load_global(global_name, add=True)

    def make_call_generated_code(self, fn_name: str) -> None:
        """Call the generated code function stored in fn_name"""
        self.extend_output(self.load_function_name(fn_name, True))

        graphargs = self.tx.output.graphargs
        for arg in graphargs:
            if arg.pass_arg_as_tensor:
                self.add_push_null(
                    lambda: self.extend_output(
                        [
                            self.create_load_python_module(torch),
                            self.create_load_attr("as_tensor"),
                        ]
                    )
                )
                self.call_reconstruct(arg)
                self.extend_output(create_call_function(1, False))
            else:
                self.call_reconstruct(arg)

        self.extend_output(create_call_function(len(graphargs), False))
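
    # Net effect (editorial note): emits `fn_name(arg0, arg1, ...)`, where any
    # graph arg flagged pass_arg_as_tensor is first wrapped in
    # torch.as_tensor(...).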

    def load_import_from(self, module_name, object_name) -> None:
        self(AttrSource(self.tx.import_source(module_name), object_name))

    def create_call_function_kw(self, nargs, kw_names, push_null) -> List[Instruction]:
        if sys.version_info >= (3, 13):
            # 3.13: the kw names tuple is loaded as a const and CALL is
            # replaced by CALL_KW
            output = create_call_function(nargs, push_null)
            assert output[-1].opname == "CALL"
            output.insert(-1, self.create_load_const(kw_names))
            output[-1] = create_instruction("CALL_KW", arg=nargs)
            return output
        elif sys.version_info >= (3, 11):
            # 3.11/3.12: a KW_NAMES instruction precedes the call
            output = create_call_function(nargs, push_null)
            if sys.version_info >= (3, 12):
                idx = -1
                expected_inst = "CALL"
            else:
                idx = -2
                expected_inst = "PRECALL"
            assert output[idx].opname == expected_inst
            kw_names_inst = create_instruction("KW_NAMES", argval=kw_names)
            output.insert(idx, kw_names_inst)
            return output
        # <= 3.10: the kw names tuple is pushed as a const for CALL_FUNCTION_KW
        return [
            self.create_load_const(kw_names),
            create_instruction("CALL_FUNCTION_KW", arg=nargs),
        ]
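
    # Illustrative: create_call_function_kw(2, ("k",), True) emits a call of
    # the form `fn(a, k=b)` -- CALL_KW on 3.13+, KW_NAMES + CALL/PRECALL on
    # 3.11/3.12, and CALL_FUNCTION_KW before 3.11.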

    def create_delete(self, value) -> Instruction:
        return create_instruction("DELETE_FAST", argval=value)
512