# mypy: allow-untyped-defs
# This file establishes the public comptime interface to Dynamo.
# This allows Dynamo users to execute arbitrary Python code while
# Dynamo is symbolically evaluating their original programs.
#
# The goal of the public API is to give users rope, without actually
# leaking private implementation details of Dynamo.

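# A minimal usage sketch (assuming an ordinary torch.compile setup; the
# function and variable names below are illustrative):
#
#     import torch
#     from torch._dynamo.comptime import comptime
#
#     @torch.compile
#     def f(x):
#         # Runs only while Dynamo is tracing f; prints the FX graph
#         # assembled so far.  Under eager execution this is a no-op.
#         comptime.print_graph()
#         return x * 2
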
import builtins
import dis
import time
import traceback
from typing import Optional, Union

import torch
from torch.fx.experimental.symbolic_shapes import free_symbols

from .exc import unimplemented
from .variables import NewCellVariable
from .variables.constant import ConstantVariable
from .variables.misc import ClosureVariable
from .variables.tensor import SymNodeVariable


class ComptimeVar:
    """
    A ComptimeVar represents a Python value, at some particular point
    in time, in the Python code we are symbolically evaluating with
    TorchDynamo.  This must be distinguished from a runtime value, as
    at compile time there are some properties of the variable we
    do not know (for example, if the ComptimeVar represents a Tensor,
    we only know metadata about the tensor; we do NOT know what the
    actual data in the Tensor is).
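
    A ComptimeVar is typically obtained from ComptimeContext.get_local inside
    a comptime callback; e.g. (a sketch, where 'x' names a local of the
    function being traced, not something defined in this file)::

        comptime(lambda ctx: print(ctx.get_local("x").python_type()))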
34    """
35
36    def __init__(self, v) -> None:
37        self.__variable = v
38
39    def as_proxy(self):
40        """
41        Returns an fx.Proxy (or tuple/list of fx.Proxy) representing
42        this variable in the FX graph we are assembling to pass
43        to the user compiler.
44
45        This method only works for variables we actually track in
46        the FX graph, aka Tensors (and ints, if you are compiling
47        with dynamic shapes).  In particular, if you have a list
48        or tuple of tensors, you will get a list/tuple of proxies
49        (not a single proxy representing the entire list/tuple).
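
        For example, inside a comptime callback (a sketch; 'x' names a
        tensor local of the function being traced)::

            comptime(lambda ctx: print(ctx.get_local("x").as_proxy().node))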
50        """
51        return self.__variable.as_proxy()
52
53    def is_proxy(self):
54        """
55        Returns True if as_proxy() would succeed.
56        """
57        return self.__variable.is_proxy()
58
59    def as_fake(self):
60        """
61        Returns a "fake" value (either a FakeTensor or a SymInt)
62        representing the variable in question.  This only works
63        for variables that denote Tensor or int.  You can use
64        this to query metadata; e.g., v.as_fake().size(0) will
65        tell you the compile-time known size of the tensor.
66
67        WARNING: Do NOT mutate the returned tensor.
68        """
69        return self.__variable.as_proxy().node.meta["example_value"]
70
71    def size(self, dim: Optional[int] = None) -> Union[int, torch.SymInt]:
72        """
73        Returns the size of the tensor (if dim is None) or the size
74        at the dimension dim.  The returned size may be a SymInt.
75        """
76        return self.as_fake().size(dim)
77
78    def python_type(self):
79        """
80        Returns what type(v) would have returned for the variable
81        at compile time.
82        """
83        return self.__variable.python_type()
84
85    def as_python_constant(self):
86        """
87        Returns the Python value this variable would have, but only if it is
88        completely known at compile-time (e.g., it is constant).
89
90        WARNING: Do NOT mutate the returned constant.  The returned constant
91        may or may not correspond to the actual value this variable may take
92        on at runtime; for example, if the variable in question is a constant
93        list, we may return a copy of that list.
94        """
95        return self.__variable.as_python_constant()
96
97    def is_python_constant(self):
98        """
99        Returns True if as_python_constant would succeed.
100        """
101        return self.__variable.is_python_constant()
102
    def is_dynamic(self):
        """
        Returns True if this variable is a symbolic value (e.g., a SymInt)
        that depends on free symbols under dynamic shapes, and False
        otherwise.
        """
        if isinstance(self.__variable, SymNodeVariable):
            fs = free_symbols(self.__variable.sym_num)
            return bool(fs)
        return False

    def force_static(self):
        """
        Forces the value to be static, inducing a guard on its specific value
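
        In user code this is normally reached through the convenience
        wrapper; e.g. (a sketch, where 'n' names an int local of the
        function being traced)::

            from torch._dynamo.comptime import comptime
            comptime.force_static(n)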
112        """
113        if isinstance(self.__variable, SymNodeVariable):
114            self.__variable.evaluate_expr()
115        elif isinstance(self.__variable, ConstantVariable):
116            # TODO: Maybe complain if this isn't a int/bool/float variable
117            pass
        else:
            raise AssertionError(
                f"cannot force {self.__variable} ({type(self.__variable)}) static"
            )

    def _i_will_not_complain_if_bc_breaks_VariableTracker(self):
        """
        Returns the internal data structure VariableTracker that Dynamo uses
        to represent variables at compile time.  There are no BC guarantees on
        this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if you rely on
        it.
        """
        return self.__variable

    def __repr__(self) -> str:
        return self.__variable.debug_repr()

    # TODO: API for adding a custom guard


class ComptimeContext:
    """
    This context class provides access to a public API for Dynamo's internals.
    If there is something here you would find useful that is missing, please
    file a feature request at https://github.com/pytorch/pytorch/
    """

    def __init__(self, tx) -> None:
        self.__tx = tx

    def get_local(self, name: str, *, stacklevel=0) -> ComptimeVar:
        """
        Retrieve the information known at compile time about a local variable.
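
        For example, inside a comptime callback (a sketch; 'y' names a
        local of the function being traced)::

            comptime(lambda ctx: ctx.print(ctx.get_local("y")))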
151        """
152        tx = self.__get_tx(stacklevel)
153
154        # This is analogous to LOAD_DEREF
155        if hasattr(tx, "closure_cells") and name in tx.closure_cells:
156            cell = tx.closure_cells[name]
157            if isinstance(cell, ClosureVariable):
158                return ComptimeVar(tx.output.root_tx.symbolic_locals[cell.name])
159            else:
160                return ComptimeVar(tx.output.side_effects.load_cell(cell))
161        else:
162            r = tx.symbolic_locals[name]
163            if isinstance(r, NewCellVariable):
164                return ComptimeVar(tx.output.side_effects.load_cell(r))
165            else:
166                return ComptimeVar(r)
167
168    def graph_break(self, msg="ComptimeContext.graph_break"):
169        """
170        Manually trigger a graph break
171        """
172        unimplemented(msg)
173
174    def graph(self):
175        """
176        Retrieve the partially constructed FX graph that would be
177        passed to the user compiler after compilation.
178        """
179        return self.__tx.output.graph
180

    def assert_static(self, val):
        """
        Asserts that the value is static (i.e., not dynamic, per dynamic shapes)
        """
        assert (
            not val.is_dynamic()
        ), "expected static but got dynamic (run with TORCH_LOGS=dynamic for more info)"

    def print_graph(self, *, verbose=True, file=None):
        """
        Print the partially constructed FX graph that would be passed
        to the user compiler after compilation.
        """
        print(
            self.__tx.output.graph.python_code("self", verbose=verbose).src, file=file
        )

    def parent(self):
        """
        Returns a ComptimeContext for the parent frame (the frame that
        invoked the one currently being symbolically evaluated).
        """
        return ComptimeContext(self.__tx.parent)

    def __get_tx(self, stacklevel):
        tx = self.__tx
        for _ in range(stacklevel):
            tx = tx.parent
        return tx

    def print(self, val, *, file=None):
        """
        Print repr(val); for a ComptimeVar this is its compile-time debug
        representation.
        """
        print(repr(val), file=file)

    def print_disas(self, *, file=None, stacklevel=0):
        """
        Print the current series of opcodes being executed (not including
        parent frames), including where you are in the particular opcode
        stream.
        """
        tx = self.__get_tx(stacklevel)
        print(
            dis.Bytecode(
                tx.f_code,
                current_offset=tx.instructions[tx.instruction_pointer].offset,
            ).dis(),
            file=file,
        )

    def print_value_stack(self, *, file=None, stacklevel=0):
        """
        Print the current Python value stack.  Note that this is NOT the same
        as the traceback; use print_bt() to print that.  Note that at
        stacklevel=0, this will typically be empty, as comptime cannot
        currently be used in an expression context where there would be
        intermediates on the stack.  If you would find this useful, please
        file a bug at https://github.com/pytorch/pytorch/

        NB: Stack grows downwards in our print
        """
        tx = self.__get_tx(stacklevel)
        for s in tx.stack:
            print(f"- {s.debug_repr()}", file=file)

    def print_locals(self, *, file=None, stacklevel=0):
        """
        Print all of the locals available in the current context.
        By default this view is very limited; you can get more information
        about any individual local using get_local().
        """
        tx = self.__get_tx(stacklevel)
        for k, v in tx.symbolic_locals.items():
            print(f"{k} = {v.debug_repr()}", file=file)

    def print_bt(self, *, file=None, stacklevel=0):
        """
        Print the user code backtrace, starting at the beginning of the
        frame Dynamo started evaluating.  Note that this MAY NOT go all
        the way to the torch.compile invocation, as we may have done
        a graph break and are compiling an intermediate frame as the
        starting point.  If you think the other behavior would be better,
        file a bug at https://github.com/pytorch/pytorch/
        """
        stack = []
        tx = self.__get_tx(stacklevel)
        while tx is not None:
            stack.append(tx.frame_summary())
            tx = getattr(tx, "parent", None)
        print(
            "".join(traceback.StackSummary.from_list(reversed(stack)).format()),
            file=file,
        )

    def print_guards(self, *, file=None):
        """
        Print the currently installed guards for the Dynamo context.
        This does NOT include guards associated with variables that
        may or may not be installed in the future if those variables
        are used.
        """
        # TODO: improve print format, current guard format is extremely
        # verbose
        print(
            "\n".join(f"{repr(guard)}" for guard in sorted(self.__tx.output.guards)),
            file=file,
        )

    def _i_will_not_complain_if_bc_breaks_InstructionTranslator(self):
        """
        Returns the internal data structure InstructionTranslator that Dynamo
        uses to track state of symbolic evaluation.  There are no BC
        guarantees on this API and WE RESERVE THE RIGHT TO BREAK YOUR CODE if
        you rely on it.
        """
        return self.__tx

    def sleep(self, sec):
        time.sleep(sec)


class _Comptime:
    @staticmethod
    def __call__(fn, fallback_fn=lambda: None):
        """fn gets called at compile time in TorchDynamo; fallback_fn gets called otherwise"""
        fallback_fn()

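    # For example (a sketch; the function below is illustrative):
    #
    #     @torch.compile
    #     def f(x):
    #         # At compile time Dynamo runs the first lambda with a
    #         # ComptimeContext; under eager execution only the fallback runs.
    #         comptime(lambda ctx: ctx.print_bt(), lambda: print("running eagerly"))
    #         return x + 1
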
    # Convenience wrappers that are more compact to use

    @staticmethod
    def graph_break():
        comptime(lambda ctx: ctx.graph_break())

    @staticmethod
    def print(e):
        comptime(lambda ctx: ctx.print(ctx.get_local("e")), lambda: print(e))

    @staticmethod
    def print_graph():
        comptime(lambda ctx: ctx.print_graph())

    @staticmethod
    def print_disas(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_disas(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    @staticmethod
    def print_value_stack(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_value_stack(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    # This is a more useful variant of print_value_stack that can be used
    # in an expression context; e.g., in x + print_value_stack_and_return(y + z),
    # you will see x on the stack prior to the addition operation.
    @staticmethod
    def print_value_stack_and_return(e, *, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_value_stack(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )
        return e

    @staticmethod
    def print_locals(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_locals(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    @staticmethod
    def print_bt(*, stacklevel=0):
        comptime(
            lambda ctx: ctx.print_bt(
                stacklevel=ctx.get_local("stacklevel").as_python_constant() + 1
            )
        )

    @staticmethod
    def print_guards():
        comptime(lambda ctx: ctx.print_guards())

    @staticmethod
    def assert_static(val):
        comptime(lambda ctx: ctx.assert_static(ctx.get_local("val")))

    @staticmethod
    def force_static(val):
        comptime(lambda ctx: ctx.get_local("val").force_static())

    @staticmethod
    def breakpoint():
374        """
375        Like pdb breakpoint(), but drop into pdb whenever this line
376        of code is compiled by dynamo.  Use it by putting
377        this in your model code::

            from torch._dynamo.comptime import comptime
            comptime.breakpoint()

        And then, inside pdb, you can access 'ctx' to query things
        about the compilation context::

            (Pdb) !ctx.print_bt()
            (Pdb) !ctx.print_locals()
            (Pdb) p ctx.get_local("attention").as_fake()
        """

        def inner(inner_ctx):
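            # Bind the parent context as 'ctx' so it is visible from the pdb
            # prompt (see the docstring above); it is otherwise unused here.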
            ctx = inner_ctx.parent()
            builtins.breakpoint()

        comptime(inner)

    @staticmethod
    def sleep(sec):
        comptime(lambda ctx: ctx.sleep(ctx.get_local("sec").as_python_constant()))


comptime = _Comptime()