# mypy: allow-untyped-defs
# ruff: noqa: TCH004
from dataclasses import dataclass
from typing import TYPE_CHECKING

import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import trace_rules, variables
from .comptime import comptime
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
from .external_utils import is_compiling

if TYPE_CHECKING:
    from torch._C._dynamo.eval_frame import (  # noqa: F401
        reset_code,
        set_eval_frame,
        set_guard_error_hook,
        skip_code,
        unsupported,
    )
else:
    for name in dir(torch._C._dynamo.eval_frame):
        if name.startswith("__"):
            continue
        globals()[name] = getattr(torch._C._dynamo.eval_frame, name)


def run(fn=None):
    """Don't do any dynamic compiles; just use prior optimizations"""
    if fn is not None:
        fn = innermost_fn(fn)
        assert callable(fn)
        return RunOnlyContext()(fn)
    return RunOnlyContext()

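# Example usage of `run` (a minimal sketch; `fn` and the inputs are hypothetical):
#
#     compiled_fn = torch.compile(fn)
#     compiled_fn(warmup_input)                   # may trigger compilation
#     run_only = torch._dynamo.run(compiled_fn)
#     run_only(new_input)                         # reuses cached code; never recompiles
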

def disable(fn=None, recursive=True):
    """
    Decorator and context manager to disable TorchDynamo.

    If recursive=True, Dynamo is completely skipped on the decorated function
    frame as well as on any recursively invoked functions.

    If recursive=False, Dynamo skips frames associated with the function code,
    but still processes recursively invoked frames.
    """
    if recursive:
        if fn is not None:
            fn = innermost_fn(fn)
            assert callable(fn)
            return DisableContext()(fn)
        return DisableContext()
    else:
        return skip(fn)

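# Example usage of `disable` (a minimal sketch; `debug_print` and `helper` are hypothetical):
#
#     @torch._dynamo.disable
#     def debug_print(x):
#         print(x.shape)   # side effect that we never want Dynamo to trace
#
#     # With recursive=False only this function's frame is skipped; functions
#     # it calls may still be traced:
#     helper = torch._dynamo.disable(helper, recursive=False)
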

def skip(fn=None):
    """
    Skip frames associated with the function code, but still process recursively
    invoked frames.
    """
    if fn is None:
        return skip
    fn = innermost_fn(fn)
    assert callable(fn)
    skip_code(fn.__code__)
    fn._torchdynamo_disable = True
    return fn


def assume_constant_result(fn):
    """
    Mark a function so that Dynamo assumes its result is constant: the call is
    evaluated at trace time and the returned value is treated as a constant in
    the graph.
    """
    fn._dynamo_marked_constant = True
    return fn

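# Example usage of `assume_constant_result` (a minimal sketch; `get_factor` and
# `read_config` are hypothetical):
#
#     @torch._dynamo.assume_constant_result
#     def get_factor():
#         return read_config("factor")   # evaluated once during tracing
#
#     @torch.compile
#     def fn(x):
#         return x * get_factor()        # the result is treated as a constant
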

def allow_in_graph(fn):
    """
    Tells the compiler frontend (Dynamo) to skip symbolic introspection of the function
    and instead directly write it to the graph when encountered.

    See :func:`torch.compiler.allow_in_graph`'s docstring for the full documentation.

    WARNING: this API can be a footgun; please read the documentation carefully.
    """
    if isinstance(fn, (list, tuple)):
        return [allow_in_graph(x) for x in fn]
    assert callable(fn), "allow_in_graph expects a callable"
    if trace_rules.lookup_callable(fn) != variables.TorchInGraphFunctionVariable:
        trace_rules._disallowed_callable_ids.remove(id(fn))
        trace_rules._allowed_callable_ids.add(id(fn))
    return fn

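# Example usage of `allow_in_graph` (a minimal sketch; `custom_op` and
# `third_party_kernel` are hypothetical):
#
#     def custom_op(x):
#         return third_party_kernel(x)   # too complex for Dynamo to introspect
#
#     torch._dynamo.allow_in_graph(custom_op)   # recorded in the graph as a single call
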

def _disallow_in_graph_helper(throw_if_not_allowed):
    def inner(fn):
        if isinstance(fn, (list, tuple)):
            return [disallow_in_graph(x) for x in fn]
        assert callable(fn), "disallow_in_graph expects a callable"
        if (
            throw_if_not_allowed
            and trace_rules.lookup_callable(fn)
            != variables.TorchInGraphFunctionVariable
            and trace_rules.lookup(fn) != variables.TorchInGraphFunctionVariable
        ):
            raise IncorrectUsage(
                "disallow_in_graph is expected to be used on an already allowed callable (like torch.* ops). "
                "Allowed callables are those that TorchDynamo puts as-is in the extracted graph."
            )
        trace_rules._allowed_callable_ids.remove(id(fn))
        trace_rules._disallowed_callable_ids.add(id(fn))
        return fn

    return inner


def disallow_in_graph(fn):
    """
    Customize which functions TorchDynamo will exclude from the generated
    graph and force a graph break on.
    ::

        torch._dynamo.disallow_in_graph(torch.sub)

        @torch._dynamo.optimize(...)
        def fn(a):
            x = torch.add(a, 1)
            x = torch.sub(x, 1)
            x = torch.add(x, 1)
            return x

        fn(...)

    Will break the graph on `torch.sub`, and give two graphs each with a
    single `torch.add()` op.
    """
    return _disallow_in_graph_helper(throw_if_not_allowed=True)(fn)


@_disallow_in_graph_helper(throw_if_not_allowed=False)
def graph_break():
    """Force a graph break"""
    pass

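# Example usage of `graph_break` (a minimal sketch; `fn` is hypothetical):
#
#     @torch.compile
#     def fn(x):
#         x = x + 1
#         torch._dynamo.graph_break()   # splits tracing into two graphs at this point
#         return x * 2
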

def forbid_in_graph(fn):
    """
    Customize which functions TorchDynamo will assert are not present while tracing.

    If you want a graph break on this function instead, use disallow_in_graph.
    TODO(voz): We now have allow_in_graph, disallow_in_graph, forbid_in_graph - some more robust
    documentation would not be amiss.
    """
    if isinstance(fn, (list, tuple)):
        return [forbid_in_graph(x) for x in fn]
    assert callable(fn), "forbid_in_graph applies only to callables"
    fn._dynamo_forbidden = True
    return fn

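# Example usage of `forbid_in_graph` (a minimal sketch; `reconfigure` is hypothetical):
#
#     @torch._dynamo.forbid_in_graph
#     def reconfigure(cfg):
#         ...   # must never be reached while Dynamo is tracing; raises instead of graph-breaking
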

# Helper function to flatten a tensor subclass and apply a function to
# all inner tensors that match the outer dim. Used to reduce duplication
# across the various marking APIs.
def _apply_func_to_inner_tensors_of_same_dim(func, t, *args, **kwargs):
    assert is_traceable_wrapper_subclass(t)

    attrs, ctx = t.__tensor_flatten__()
    for attr in attrs:
        inner = getattr(t, attr)
        if inner.dim() == t.dim():
            func(inner, *args, **kwargs)


@dataclass(frozen=True)
class _DimRange:
    """
    This represents a dimension of a tensor and the corresponding
    min and max values it can take.  Don't create this
    class directly; instead, use :func:`mark_dynamic`.
    """

    dim: int
    min: int
    max: int


@forbid_in_graph
def mark_dynamic(t, index, *, min=None, max=None):
    """
    Mark a tensor as having a dynamic dim and set the corresponding min and max range for that dim.

    [Note - on the state of mark_dynamic]

    The behavior of having a dynamic dimension on a tensor is governed by a few factors:

    1) torch._dynamo.config.dynamic_shapes being True or False.
        a) dynamic_shapes=True - dynamic_shapes must be True for mark_dynamic to work.
        b) dynamic_shapes=False - This config will raise an exception when used in conjunction with
        mark_dynamic. We will eventually support this.

    2) If the dimension is fully constrained - as in, it does not allow more than a single value
    in both eager (torch.compile, torch._dynamo.optimize) mode and export mode (torch._dynamo.export) -
    we will raise an error.

    3) If the dimension is partially constrained - allowing at least 2 values but not the full unbounded
    range of shapes - eager mode will pass it through, but export will raise an error.

    4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
    before torch.compile.

    """
    if is_traceable_wrapper_subclass(t):
        # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
        # TODO: Make this configurable via a supported public API
        _apply_func_to_inner_tensors_of_same_dim(
            mark_dynamic, t, index, min=min, max=max
        )

    if isinstance(index, int):
        if not hasattr(t, "_dynamo_dynamic_indices"):
            t._dynamo_dynamic_indices = set()
            t._dynamo_dynamic_range = set()
        # TODO(voz): Should we bounds check?
        t._dynamo_dynamic_indices.add(index)
        t._dynamo_dynamic_range.add(_DimRange(index, min, max))
        return

    assert isinstance(index, (list, tuple))
    for i in index:
        mark_dynamic(t, i, min=min, max=max)

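# Example usage of `mark_dynamic` (a minimal sketch; `model` and the shapes are hypothetical):
#
#     x = torch.randn(8, 128)
#     torch._dynamo.mark_dynamic(x, 0, min=1, max=64)   # batch dim may range over [1, 64]
#     compiled = torch.compile(model)
#     compiled(x)                                       # compiled with dim 0 treated as dynamic
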

@forbid_in_graph
def maybe_mark_dynamic(t, index):
    """
    Mark a tensor as having a dynamic dim, but don't enforce it (i.e., if this
    dimension ends up getting specialized, don't error).
    """
    if is_traceable_wrapper_subclass(t):
        # default behavior: mirror maybe_mark_dynamic() on all inner tensors with same dim as t
        # TODO: Make this configurable via a supported public API
        _apply_func_to_inner_tensors_of_same_dim(maybe_mark_dynamic, t, index)

    if isinstance(index, int):
        if not hasattr(t, "_dynamo_weak_dynamic_indices"):
            t._dynamo_weak_dynamic_indices = set()
        # TODO(voz): Should we bounds check?
        t._dynamo_weak_dynamic_indices.add(index)
        return

    assert isinstance(index, (list, tuple))
    for i in index:
        maybe_mark_dynamic(t, i)

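# Example usage of `maybe_mark_dynamic` (a minimal sketch; the shapes are hypothetical):
#
#     x = torch.randn(8, 128)
#     torch._dynamo.maybe_mark_dynamic(x, 0)   # hint only: no error if dim 0 gets specialized
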

def mark_static(t, index=None):
    """
    Mark a tensor as having a static dim.

    This will prevent us from attempting to compile it dynamically
    when dynamic=True; this can improve trace-time performance.

    This has lower precedence than mark_dynamic.

    Unlike mark_dynamic, this can be done inside a graph, in which case it
    induces specialization on the tensor.
    """
    if is_compiling():
        if index is None:
            for s in t.size():
                comptime.force_static(s)
        else:
            comptime.force_static(t.size(index))
        return

    if is_traceable_wrapper_subclass(t):
        # default behavior: mirror mark_static() on all inner tensors with same dim as t
        # TODO: Make this configurable via a supported public API
        _apply_func_to_inner_tensors_of_same_dim(mark_static, t, index)

    if isinstance(index, int):
        if not hasattr(t, "_dynamo_static_indices"):
            t._dynamo_static_indices = set()
        # TODO(voz): Should we bounds check?
        t._dynamo_static_indices.add(index)
    elif index is None:
        for i in range(t.dim()):
            mark_static(t, i)
    else:
        assert isinstance(index, (list, tuple))
        for i in index:
            mark_static(t, i)

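# Example usage of `mark_static` (a minimal sketch; the shapes are hypothetical):
#
#     x = torch.randn(8, 128)
#     torch._dynamo.mark_static(x, 1)   # keep dim 1 static even when compiling with dynamic=True
#     torch._dynamo.mark_static(x)      # or mark every dim of `x` static
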

@forbid_in_graph
def mark_static_address(t, guard=True):
    """
    Marks an input tensor whose data_ptr will not change across multiple calls
    to a dynamo-compiled function. This indicates to cudagraphs that an extra allocation
    is not needed for this input. The data_ptr will be guarded if guard=True. Note:
    Tensors marked in this way will be kept alive until `torch._dynamo.reset()` is called.
    """
    if not isinstance(t, torch.Tensor):
        raise TypeError(f"mark_static_address expects a tensor but received {type(t)}")

    if guard:
        t._dynamo_static_input_type = "guarded"  # type: ignore[attr-defined]
    else:
        t._dynamo_static_input_type = "unguarded"  # type: ignore[attr-defined]

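# Example usage of `mark_static_address` (a minimal sketch; `buf` and `model` are hypothetical):
#
#     buf = torch.zeros(1024, device="cuda")   # persistent input buffer with a stable data_ptr
#     torch._dynamo.mark_static_address(buf)   # guard on the address (guard=True)
#     torch.compile(model, mode="reduce-overhead")(buf)
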

# Note: this carefully avoids eagerly importing einops.
# TODO: we should delete this whole _allow_in_graph_einops logic by approximately 2024 Q2
def _allow_in_graph_einops():
    import einops

    try:
        # requires einops > 0.6.1, torch >= 2.0
        from einops._torch_specific import (  # type: ignore[attr-defined]  # noqa: F401
            _ops_were_registered_in_torchdynamo,
        )

        # einops > 0.6.1 will call the op registration logic as it is imported.
        pass
    except ImportError:
        # einops <= 0.6.1
        allow_in_graph(einops.rearrange)
        allow_in_graph(einops.reduce)
        if hasattr(einops, "repeat"):
            allow_in_graph(einops.repeat)  # available since einops 0.2.0
        if hasattr(einops, "einsum"):
            allow_in_graph(einops.einsum)  # available since einops 0.5.0
        if hasattr(einops, "pack"):
            allow_in_graph(einops.pack)  # available since einops 0.6.0
        if hasattr(einops, "unpack"):
            allow_in_graph(einops.unpack)  # available since einops 0.6.0


trace_rules.add_module_init_func("einops", _allow_in_graph_einops)