xref: /aosp_15_r20/external/pytorch/torch/_dynamo/variables/torch_function.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import inspect
from typing import Dict, List, TYPE_CHECKING

import torch.utils._pytree as pytree
from torch._guards import Source
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from torch.utils._device import DeviceContext

from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource, TorchFunctionModeStackSource, TypeSource
from ..utils import get_safe_global_name, has_torch_function, is_tensor_base_attr_getter
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable


if TYPE_CHECKING:
    from torch._dynamo.symbolic_convert import InstructionTranslator


# [Note: __torch_function__] This feature is a prototype and has some rough edges (contact mlazos with issues):
# At a high level, a torch function tensor subclass is represented as a TensorWithTFOverrideVariable, which dispatches
# __torch_function__ on attribute accesses, method calls, and torch API calls.
# The following is not supported:
# - triggering __torch_function__ on tensor subclass non-tensor custom attributes
# - graph breaking on mutating guardable tensor properties within a __torch_function__ context; this can cause
# excessive recompiles in certain degenerate cases
# - matching the exact eager behavior of *ignoring* __torch_function__ objects in non-tensor argument positions of torch API calls

# The following is supported:
# - static method impls of __torch_function__ on custom objects; this will trigger on torch API calls with the object as
# any argument
# - triggering __torch_function__ on torch API calls with tensor subclass arguments
# - __torch_function__ calls on base tensor attribute access and method calls for tensor subclass instances
# - matching the dispatch ordering behavior of eager __torch_function__ with subclass/object arguments in any argument position

# See https://docs.google.com/document/d/1WBxBSvW3NXhRp9ncmtokJloMLCtF4AYNhJaffvHe8Kw/edit#heading=h.vacn73lozd9w
# for more information on the design.

# To enable subclass behavior, add your tensor subclass type to traceable_tensor_subclasses in dynamo/config.py
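# A minimal sketch of that workflow (MyTensorSubclass and the plain set.add on
# traceable_tensor_subclasses are illustrative assumptions, not part of this file):
#
#     import torch
#     import torch._dynamo.config as dynamo_config
#
#     class MyTensorSubclass(torch.Tensor):
#         @classmethod
#         def __torch_function__(cls, func, types, args=(), kwargs=None):
#             kwargs = kwargs or {}
#             return super().__torch_function__(func, types, args, kwargs)
#
#     dynamo_config.traceable_tensor_subclasses.add(MyTensorSubclass)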


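# Names of base tensor attributes whose default getters are "nowrap" functions; accessing
# them on a subclass with a __torch_function__ override is not supported and will graph
# break (see TensorWithTFOverrideVariable.var_getattr below).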
banned_attrs = [
    fn.__self__.__name__
    for fn in get_default_nowrap_functions()
    if is_tensor_base_attr_getter(fn)
]

# Today, setting the default device is placed in the graph and guarded on separately,
# so we should not trace through it. In the future we can trace through it once
# mode tracing is implemented and avoid putting it in the graph, but this is more
# of a BE project and can be evaluated later
IGNORED_MODES = {DeviceContext}


class TorchFunctionModeStackVariable(VariableTracker):
    """Fake VT used as a dummy object to indicate the presence of torch function mode stack mutation"""

    # singleton value representing the global torch function mode stack
    # (the actual stack lives in C++)
    stack_value_singleton = object()

    # offset is used to track whether we have inserted/removed a
    # device context, which is always placed at the bottom of the stack.
    # If a device context is inserted, the graph will run this mutation,
    # so when we want to reconstruct any other modes on the stack,
    # their indices should be shifted right by 1 (+1).
    # Conversely, if there was a device context on the stack and the graph
    # mutates the stack to remove that context (set default device to None),
    # each of the indices of the other modes should be shifted left by 1 (-1).
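    # For example, if the graph inserts a device context, a user mode that was at
    # runtime index 0 before the mutation sits at runtime index 1 afterwards, so its
    # reconstruction index must be shifted by +1 (see get_mode_index below).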
    offset = 0

    def __init__(self, source, symbolic_stack):
        self.source = source
        self.symbolic_stack = symbolic_stack

    @classmethod
    def reset(cls):
        cls.offset = 0

    @classmethod
    def register_mutation(cls, tx: "InstructionTranslator"):
        if cls.stack_value_singleton not in tx.output.side_effects:
            var = cls(
                source=Source(), symbolic_stack=tx.symbolic_torch_function_mode_stack
            )
            tx.output.side_effects.track_mutable(cls.stack_value_singleton, var)
            tx.output.side_effects.mutation(var)

    @classmethod
    def register_device_context_insertion(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_mode_stack
        if stack and cls.is_device_context(stack[0]):
            return
        else:
            cls.offset += 1
            tx.symbolic_torch_function_mode_stack.insert(
                0,
                TorchFunctionModeVariable(
                    None, source=TorchFunctionModeStackSource(-cls.offset)
                ),
            )

    @classmethod
    def clear_default_device(cls, tx: "InstructionTranslator"):
        stack = tx.symbolic_torch_function_mode_stack
        if stack and cls.is_device_context(stack[0]):
            stack.popleft()
            cls.offset -= 1

    @staticmethod
    def is_device_context(var):
        return isinstance(var.value, DeviceContext) or var.value is None

    @classmethod
    def get_mode_index(cls, ind):
        return ind + cls.offset


class TorchFunctionModeVariable(ContextWrappingVariable):
    def __init__(self, value, **kwargs):
        super().__init__(value, **kwargs)
        self.value = value

    @staticmethod
    def get_global_mangled_name(tx, val):
        return get_safe_global_name(
            tx, f"__torch_function_mode_{val.__class__.__name__}", val
        )

    def reconstruct(self, codegen):
        # We don't support locally created torch function modes yet
        assert self.source
        self.source.reconstruct(codegen)

    def _call_func(self, tx, values):
        unimplemented("torch function mode context manager is not supported yet")


def _get_all_args(args, kwargs):
    return _flatten_vts(pytree.arg_tree_leaves(*args, **kwargs))


def _flatten_vts(vts):
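    """Flatten nested lists/dicts of VariableTrackers into a flat list of leaf VTs, realizing lazy VTs along the way."""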
    from collections import deque

    from .dicts import ConstDictVariable
    from .lazy import LazyVariableTracker
    from .lists import ListVariable

    vts = deque(vts)
    output = []

    while vts:
        vt = vts.pop()
        LazyVariableTracker.realize_all(vt)
        if isinstance(vt, ListVariable):
            vts.extend(vt.items)
        elif isinstance(vt, ConstDictVariable):
            vts.extend(vt.items.values())
        else:
            output.append(vt)

    return output


def _get_subclass_type(var):
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    return var.python_type()


def _get_subclass_type_var(tx: "InstructionTranslator", var):
    assert isinstance(var, (TensorWithTFOverrideVariable, UserDefinedObjectVariable))
    if isinstance(var, TensorWithTFOverrideVariable):
        return var.class_type_var(tx)
    elif isinstance(var, UserDefinedObjectVariable):
        from .builder import SourcelessBuilder, VariableBuilder

        if var.source:
            return VariableBuilder(tx, TypeSource(var.source))(var.python_type())
        else:
            return SourcelessBuilder.create(tx, var.python_type())


def _is_attr_overidden(tx: "InstructionTranslator", var, name):
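    """Return True if `name` on var's Python type differs from the implementation defined on the base torch.Tensor."""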
    import torch

    overridden = False
    try:
        attr_val = inspect.getattr_static(var.python_type(), name)
        overridden |= attr_val != getattr(torch.Tensor, name)
    except AttributeError:
        pass

    return overridden


def call_torch_function(
    tx, torch_function_type, torch_function_var, fn, types, args, kwargs
):
    from .builder import SourcelessBuilder

    # signature:
    # def __torch_function__(cls, func, types, args=(), kwargs=None):
    tf_args = (
        torch_function_type,
        fn,
        types,
        SourcelessBuilder.create(tx, tuple(args)),
        SourcelessBuilder.create(tx, kwargs),
    )
    return tx.inline_user_function_return(torch_function_var, tf_args, {})


def build_torch_function_fn(tx: "InstructionTranslator", value, source):
    from .builder import SourcelessBuilder, VariableBuilder

    if source:
        return VariableBuilder(
            tx,
            AttrSource(AttrSource(source, "__torch_function__"), "__func__"),
        )(value.__torch_function__.__func__)
    else:
        return SourcelessBuilder.create(tx, value.__torch_function__.__func__)


def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs):
    return tx.output.torch_function_enabled and any(
        has_torch_function(arg) for arg in _get_all_args(args, kwargs)
    )


def dispatch_torch_function(tx: "InstructionTranslator", fn, args, kwargs):
    """Gathers all args that are TensorWithTFOverrideVariable and dispatches based on the ordering in _get_overloaded_args"""

    all_args = _get_all_args(args, kwargs)
    overloaded_args = _get_overloaded_args(
        [arg for arg in all_args if has_torch_function(arg)],
        _get_subclass_type,
    )

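    # Mirror eager __torch_function__ dispatch semantics: try each overloaded arg in
    # priority order and return the first result that is not NotImplemented.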
    for arg in overloaded_args:
        res = arg.call_torch_function(
            tx,
            fn,
            TupleVariable([_get_subclass_type_var(tx, arg) for arg in overloaded_args]),
            args,
            kwargs,
        )

        if not (isinstance(res, ConstantVariable) and res.value is NotImplemented):
            return res

    unimplemented(
        f"All __torch_function__ overrides for call {fn} with args {args} and kwargs {kwargs} returned NotImplemented"
    )


class TensorWithTFOverrideVariable(TensorVariable):
    """
    Represents a tensor subclass instance with a __torch_function__ override.
    """

    def __init__(self, *args, **kwargs) -> None:
        self.torch_function_fn = kwargs.pop("torch_function_fn")
        super().__init__(*args, **kwargs)

    @classmethod
    def from_tensor_var(cls, tx, tensor_var, class_type, torch_function_fn):
        import torch

        kwargs = dict(tensor_var.__dict__)
        assert (
            kwargs.pop("class_type") is torch.Tensor
        ), "invalid class type in TensorWithTFOverrideVariable.from_tensor_var"
        var = cls(torch_function_fn=torch_function_fn, class_type=class_type, **kwargs)
        var.install_global(tx)
        return var

    def install_global(self, tx):
        # Stash the subclass type so an output tensor can be rewrapped if needed;
        # the actual type needs to be available each time the compiled artifact
        # is run and outputs a wrapped tensor.
        if self.global_mangled_class_name(tx) not in tx.output.global_scope:
            # Safe because global_mangled_class_name produces a mangled, unique name for this class type
            tx.output.install_global_unsafe(
                self.global_mangled_class_name(tx), self.class_type
            )

    def python_type(self):
        return self.class_type

    def class_type_var(self, tx):
        return TensorSubclassVariable(
            self.class_type, source=GlobalSource(self.global_mangled_class_name(tx))
        )

    def global_mangled_class_name(self, tx):
        return get_safe_global_name(
            tx, f"__subclass_{self.class_type.__name__}", self.class_type
        )

    def var_getattr(self, tx: "InstructionTranslator", name):
        # [Note: __torch_function__] We currently only support attributes that are defined on
        # base tensors; custom attribute accesses will graph break.
        import torch

        from .builder import SourcelessBuilder

        if name in banned_attrs:
            unimplemented(
                f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported"
            )

        if _is_attr_overidden(tx, self, name):
            unimplemented(
                f"Accessing overridden method/attribute {name} on a tensor"
                " subclass with a __torch_function__ override is not supported"
            )

        if tx.output.torch_function_enabled and hasattr(torch.Tensor, name):
            if self.source:
                install_guard(
                    AttrSource(AttrSource(self.source, "__class__"), name).make_guard(
                        GuardBuilder.FUNCTION_MATCH
                    )
                )
            get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__)

            return self.call_torch_function(
                tx,
                get_fn,
                TupleVariable([self.class_type_var(tx)]),
                [self],
                {},
            )
        else:
            return super().var_getattr(tx, name)

    def call_torch_function(self, tx: "InstructionTranslator", fn, types, args, kwargs):
        return call_torch_function(
            tx,
            self.class_type_var(tx),
            self.torch_function_fn,
            fn,
            types,
            args,
            kwargs,
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        # This code block implements inlining the __torch_function__ override
        # of `call_method`.
        if tx.output.torch_function_enabled:
            import torch

            from .builder import SourcelessBuilder, VariableBuilder

            if _is_attr_overidden(tx, self, name):
                unimplemented(
                    f"Calling overridden method {name} on a tensor"
                    " subclass with a __torch_function__ override is not supported"
                )

            # [Note: __torch_function__] Currently we only support methods that are defined on tensor;
            # we will graph break in other cases. This will need a bigger overhaul of extracting methods/comparing them for equality.
            # We've established with the above check that the method is not overridden, so we guard that the method is the same
            # as the impl defined on tensor and retrieve it.
            if self.source:
                func_var = VariableBuilder(
                    tx, AttrSource(AttrSource(self.source, "__class__"), name)
                )(inspect.getattr_static(self.python_type(), name))
            else:
                func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name))
            return dispatch_torch_function(tx, func_var, [self] + args, kwargs)
        else:
            return super().call_method(tx, name, args, kwargs)