# mypy: allow-untyped-defs
import contextlib

import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type
from typing_extensions import TypeGuard
from collections import deque

import torch
import torchgen
import torchgen.model
from torch._C import (
    _get_dispatch_stack_at,
    _len_torch_dispatch_stack,
    _pop_torch_dispatch_stack,
    _push_on_torch_dispatch_stack,
    DispatchKey,
)


# TODO: Limitations and things about enable_torch_dispatch_mode we should fix before exposing it:
# - We need a better user-facing api for _DisableTorchDispatch that
#   is able to selectively disable __torch_dispatch__ of a particular class.
# - It doesn't work with the tensor constructors (torch.tensor, torch.Tensor)
# - Better name (see https://github.com/pytorch/pytorch/pull/63496#discussion_r694091694)

_is_in_torch_dispatch_mode = False
_is_in_non_infra_torch_dispatch_mode = False


def is_in_torch_dispatch_mode(include_infra_modes=True) -> bool:
    return _is_in_torch_dispatch_mode if include_infra_modes else _is_in_non_infra_torch_dispatch_mode
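
# Example behavior (an illustrative sketch; ``SomeMode`` stands in for any
# non-infra TorchDispatchMode subclass):
#
#     assert not is_in_torch_dispatch_mode()
#     with SomeMode():
#         assert is_in_torch_dispatch_mode()
#         # non-infra modes also flip the stricter flag:
#         assert is_in_torch_dispatch_mode(include_infra_modes=False)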


class TorchDispatchMode:
    """
    A ``TorchDispatchMode`` allows you to override the meaning of all
    ``__torch_dispatch__`` overrideable functions within a dynamic scope,
    without having to actually create a tensor subclass or manually
    monkey-patch functions in the PyTorch API.  Some common situations
    where you should use a mode:

        * You want to override the meaning of factory functions, or other
          functions that do not otherwise take a tensor as an argument
          (these cannot be overridden with tensor subclasses).

        * You want to override the behavior of all functions without needing
          to wrap your inputs in tensor subclasses; e.g., if you are just
          interested in logging intermediate computations.

        * You want to control the order of execution of various tensor
          subclasses explicitly, rather than implicitly via the return of
          ``NotImplemented``.

    Independent subclasses of :class:`TorchDispatchMode` are compositional:
    modes can be pushed onto a stack using ``with MyMode():``.
    When you call functions in the PyTorch API inside your
    ``__torch_dispatch__`` implementation, by default, they will forward on to
    the next mode on the mode stack.  If you want to recursively call back into
    your current ``__torch_dispatch__`` implementation, either explicitly
    invoke ``self.__torch_dispatch__(...)``, or use the context manager
    ``__torch_dispatch__(self)`` to make the PyTorch
    API self-referential (beware of infinite loops, in this case!)
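
    Example (a minimal logging-mode sketch; ``LoggingMode`` is a hypothetical
    name, not a class shipped with PyTorch)::

        class LoggingMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                print(func)
                return func(*args, **kwargs)

        with LoggingMode():
            torch.ones(2) + 1  # prints each aten op hit (aten.ones, aten.add, ...)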
    """

    def __init__(self, _dispatch_key=None):
        if _dispatch_key is not None:
            assert isinstance(_dispatch_key, torch._C.DispatchKey)
            self.__dict__["_dispatch_key"] = _dispatch_key

        self.old_dispatch_mode_flags: Deque[bool] = deque()
        self.old_non_infra_dispatch_mode_flags: Deque[bool] = deque()

    def _lazy_init_old_dispatch_mode_flags(self):
        if not hasattr(self, "old_dispatch_mode_flags"):
            self.old_dispatch_mode_flags: Deque[bool] = deque()  # type: ignore[no-redef]

        if not hasattr(self, "old_non_infra_dispatch_mode_flags"):
            self.old_non_infra_dispatch_mode_flags: Deque[bool] = deque()  # type: ignore[no-redef]

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        raise NotImplementedError

    def __enter__(self):
        global _is_in_torch_dispatch_mode
        global _is_in_non_infra_torch_dispatch_mode
        # Previously, there wasn't any state in this class's constructor.
        # Super calls were added to existing modes, but for any new modes
        # this will replicate the previous behavior of not strictly needing
        # to call super().__init__().
        self._lazy_init_old_dispatch_mode_flags()
        self.old_dispatch_mode_flags.append(_is_in_torch_dispatch_mode)
        _is_in_torch_dispatch_mode = True
        self.old_non_infra_dispatch_mode_flags.append(_is_in_non_infra_torch_dispatch_mode)
        _is_in_non_infra_torch_dispatch_mode = _is_in_non_infra_torch_dispatch_mode or not self.is_infra_mode()
        _push_mode(self)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        mb_dk_or_mode_key = self.__dict__.get("_dispatch_key", None)
        if mb_dk_or_mode_key is None:
            # Today, mode keys are not used at all in the per-dispatch-key-mode logic (for pre-dispatch)
            # We should probably revisit this.
            mb_dk_or_mode_key = self.__dict__.get("_mode_key", None)
        global _is_in_torch_dispatch_mode
        _is_in_torch_dispatch_mode = self.old_dispatch_mode_flags.pop()
        global _is_in_non_infra_torch_dispatch_mode
        _is_in_non_infra_torch_dispatch_mode = self.old_non_infra_dispatch_mode_flags.pop()
        _pop_mode(mb_dk_or_mode_key)

    @classmethod
    def push(cls, *args, **kwargs):
        warnings.warn(
            "`Mode.push()` is no longer necessary and can be replaced with just `with Mode()`"
        )
        instance = cls(*args, **kwargs)
        return instance

    @classmethod
    def is_infra_mode(cls):
        return False


def _get_current_dispatch_mode():
    stack_len = _len_torch_dispatch_stack()
    # Return a user mode on the stack if there are any
    if stack_len > 0:
        return _get_dispatch_stack_at(stack_len - 1)
    return None


def _detect_infra_mode(key):
    assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY]
    from torch._ops import _get_dispatch_mode_pre_dispatch

    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
    post_dispatch_mode = torch._C._get_dispatch_mode(key)

    assert (pre_dispatch_mode is None) or (post_dispatch_mode is None)

    if pre_dispatch_mode is None:
        return post_dispatch_mode

    return pre_dispatch_mode


def _unset_infra_mode(key):
    from torch._ops import _get_dispatch_mode_pre_dispatch, unset_mode_pre_dispatch

    pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(key)
    post_dispatch_mode = torch._C._get_dispatch_mode(key)
    if pre_dispatch_mode and post_dispatch_mode:
        raise AssertionError(
            "Can't have active infra mode on both pre and post dispatch mode stack"
        )

    if pre_dispatch_mode:
        mode = unset_mode_pre_dispatch(key)
        return mode
    if post_dispatch_mode:
        return torch._C._unset_dispatch_mode(key)


@contextlib.contextmanager
def _disable_infra_mode(key):
    assert key in (
        torch._C._TorchDispatchModeKey.FUNCTIONAL,
        torch._C._TorchDispatchModeKey.PROXY,
    )
    mode_unset = _unset_infra_mode(key)
    try:
        yield mode_unset
    finally:
        if mode_unset is not None:
            _push_mode(mode_unset)


def _get_current_dispatch_mode_stack():
    stack_len = _len_torch_dispatch_stack()
    return [_get_dispatch_stack_at(i) for i in range(stack_len)]


def _push_mode(mode: TorchDispatchMode):
    k = mode._dispatch_key if hasattr(mode, "_dispatch_key") else None
    assert k is None or k == torch._C.DispatchKey.PreDispatch
    if k is None:
        _push_on_torch_dispatch_stack(mode)
        return

    from torch._ops import _set_mode_pre_dispatch, get_cached_ops

    # See Note [Not Caching Per-Dispatch-Key Mode Handlers]
    # Clear the cache of every op that has been used so far, for this particular key.
    ks = torch._C._functionality_to_backend_keys(k)
    for op in get_cached_ops():
        for key in ks:
            op._uncache_dispatch(key)
    _set_mode_pre_dispatch(mode)


def _pop_mode(k: Optional[Union[DispatchKey, torch._C._TorchDispatchModeKey]] = None):
    if k == torch._C.DispatchKey.PreDispatch:  # type: ignore[attr-defined]
        from torch._ops import _pop_mode_from_pre_dispatch

        return _pop_mode_from_pre_dispatch()

    if k is None or isinstance(k, torch._C._TorchDispatchModeKey):
        return _pop_torch_dispatch_stack(k)


@contextlib.contextmanager
def _pop_mode_temporarily(k: Optional[DispatchKey] = None):
    old = _pop_mode(k)
    try:
        yield old
    finally:
        _push_mode(old)


@contextlib.contextmanager
def _disable_current_modes():
    from torch._ops import (
        _len_torch_dispatch_stack_pre_dispatch,
        _pop_mode_from_pre_dispatch,
    )
    from torch._subclasses.functional_tensor import FunctionalTensorMode
    from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode
    from torch._subclasses.schema_check_mode import SchemaCheckMode

    mode_len_pre_dispatch = _len_torch_dispatch_stack_pre_dispatch()
    old_pre_dispatch_modes = [
        _pop_mode_from_pre_dispatch() for _ in range(mode_len_pre_dispatch)
    ]

    has_proxy_mode_in_pre_dispatch = False
    has_functional_mode_in_pre_dispatch = False
    has_schema_check_mode_in_pre_dispatch = False

    for i in old_pre_dispatch_modes:
        if isinstance(i, ProxyTorchDispatchMode):
            has_proxy_mode_in_pre_dispatch = True
        if isinstance(i, FunctionalTensorMode):
            has_functional_mode_in_pre_dispatch = True
        if isinstance(i, SchemaCheckMode):
            has_schema_check_mode_in_pre_dispatch = True

    mode_len = _len_torch_dispatch_stack()
    old_modes = [_pop_mode() for _ in range(mode_len)]

    for old in old_modes:
        if (
            isinstance(old, FunctionalTensorMode)
            and has_functional_mode_in_pre_dispatch
        ):
            raise AssertionError(
                "Can't have FunctionalMode available both in PreDispatch and Python Key"
            )
        if isinstance(old, ProxyTorchDispatchMode) and has_proxy_mode_in_pre_dispatch:
            raise AssertionError(
                "Can't have ProxyTorchDispatchMode available both in PreDispatch and Python Key"
            )
        if (
            isinstance(old, SchemaCheckMode)
            and has_schema_check_mode_in_pre_dispatch
        ):
            raise AssertionError(
                "Can't have SchemaCheckMode available both in PreDispatch and Python Key"
            )

    # All modes have been popped at this point; run the body with no modes
    # active, then restore both stacks in their original order.
    try:
        yield old_pre_dispatch_modes + old_modes
    finally:
        for mode in reversed(old_modes):
            _push_mode(mode)
        for mode in reversed(old_pre_dispatch_modes):
            _push_mode(mode)


class BaseTorchDispatchMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


# Subtypes which have __tensor_flatten__ and __tensor_unflatten__.
class TensorWithFlatten(Protocol):
    def __tensor_flatten__(self) -> Tuple[Sequence[str], object]:
        ...

    @staticmethod
    def __tensor_unflatten__(
        inner_tensors: Dict[str, torch.Tensor],
        flatten_spec: object,
        outer_size: torch.Size,
        outer_stride: Tuple[int, ...],
    ) -> torch.Tensor:
        ...

    # It would be really nice to be able to say that the return of
    # is_traceable_wrapper_subclass() is Intersection[torch.Tensor,
    # TensorWithFlatten] - but that doesn't exist.

    shape: torch._C.Size

    @overload
    def stride(self, dim: None = None) -> Tuple[int, ...]:
        ...

    @overload
    def stride(self, dim: int) -> int:
        ...

    def dim(self) -> int:
        ...

    @overload
    def to(
            self,
            dtype: torch.types._dtype,
            non_blocking: bool = False,
            copy: bool = False,
            *,
            memory_format: Optional[torch.memory_format] = None
    ) -> torch.Tensor:
        ...

    @overload
    def to(
            self,
            device: Optional["torch._prims_common.DeviceLikeType"] = None,
            dtype: Optional[torch.types._dtype] = None,
            non_blocking: bool = False,
            copy: bool = False,
            *,
            memory_format: Optional[torch.memory_format] = None
    ) -> torch.Tensor:
        ...

    @overload
    def to(
            self,
            other: torch.Tensor,
            non_blocking: bool = False,
            copy: bool = False,
            *,
            memory_format: Optional[torch.memory_format] = None
    ) -> torch.Tensor:
        ...


def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]:
    """
    Returns whether or not a tensor subclass that implements __torch_dispatch__
    is 'traceable' with torch.compile.
    In order for a tensor subclass to support TorchDispatchMode-style tracing in PT2,
    it must implement two magic methods: __tensor_flatten__ and __tensor_unflatten__.
    It is also expected to obey some restrictions around traceability and aliasing:
        * The subclass's __torch_dispatch__() implementation should desugar into PyTorch
            dispatcher operations that can be traced into a graph.
        * The subclass should use return_and_correct_aliasing(). This is needed today to make
            sure that torch.compile does the right thing in a few cases around input mutation
            and output aliasing.

    Expected magic method signatures:
        attrs, ctx = t.__tensor_flatten__()
            attrs: list of attribute name strings for inner tensors
            ctx: dict containing any other subclass-specific metadata needed for unflattening

        t = MySubClass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)
            inner_tensors: dict mapping attribute name -> tensor for each inner tensor
            ctx: dict with subclass metadata in the form that __tensor_flatten__() produces
            outer_size: expected (possibly symbolic) size that the returned subclass
                instance should have. Note that this arg is useful for certain subclasses
                that require the shape info to be constructed. In most cases, this arg can be
                safely ignored.
            outer_stride: expected (possibly symbolic) stride that the returned subclass
                instance should have. Note that this arg is useful for certain subclasses
                that require the stride info to be constructed. In most cases, this arg can be
                safely ignored.
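
    Example of the expected protocol (an illustrative sketch; ``WrapperTensor``
    is a hypothetical single-inner-tensor wrapper, not a class shipped with
    PyTorch)::

        class WrapperTensor(torch.Tensor):
            elem: torch.Tensor

            def __tensor_flatten__(self):
                # one inner tensor attribute, no extra metadata
                return ["elem"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
                assert ctx is None
                return WrapperTensor(inner_tensors["elem"])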
    """
    is_subclass = isinstance(t, torch.Tensor) and type(t) is not torch.Tensor
    return (
        is_subclass
        and hasattr(t, "__tensor_flatten__")
        and hasattr(t, "__tensor_unflatten__")
    )


def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]:
    """Same as above, but takes a type argument instead of an instance."""
    return (issubclass(t, torch.Tensor) and t is not torch.Tensor
            and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__"))


def transform_subclass(t, callback, outer_size=None, outer_stride=None):
    """
    Given a traceable, wrapper tensor subclass ``t`` that implements
    ``__torch_dispatch__`` and holds some inner tensors,
    and a callback of type ``Callable[[str, torch.Tensor], torch.Tensor]``,
    `transform_subclass` will construct a fresh instance of the wrapper tensor subclass.
    It will do so by grabbing each inner tensor attribute from the wrapper,
    passing it into ``callback`` to get a transformed tensor,
    and putting each transformed tensor into the fresh tensor subclass instance.

    Note: this function will not handle ensuring that the fresh subclass
    gets the same (autograd, and aliasing) metadata as the original tensor.
    This is generally handled in other subsystems like AOTAutograd.
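
    Example (an illustrative sketch; assumes ``t`` is a traceable wrapper
    subclass in the sense of ``is_traceable_wrapper_subclass``)::

        # Rebuild ``t`` with every inner tensor detached from autograd:
        fresh = transform_subclass(t, lambda attr, inner: inner.detach())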
    """
    outer_size = outer_size if outer_size is not None else t.size()
    outer_stride = outer_stride if outer_stride is not None else t.stride()

    attrs, ctx = t.__tensor_flatten__()
    transformed_tensors_dict = {}
    for attr in attrs:
        transformed_tensors_dict[attr] = callback(attr, getattr(t, attr))
    sub = type(t).__tensor_unflatten__(
        transformed_tensors_dict, ctx, outer_size, outer_stride
    )

    # NB: Purposefully guard here to simplify the inner / outer symbols.
    # Using sym_eq() for symbolic comparison can result in an expression that's too
    # difficult to guard on, so we use == here.
    assert sub.shape == outer_size, (
        f"Expected return value from {type(t)}.__tensor_unflatten__() to have "
        f"shape equal to {outer_size}, but got: {sub.shape}"
    )
    assert sub.stride() == outer_stride, (
        f"Expected return value from {type(t)}.__tensor_unflatten__() to have "
        f"stride equal to {outer_stride}, but got: {sub.stride()}"
    )

    return sub


def _correct_storage_aliasing(func, schema_info, args, outs):
    """
    Given: an OpOverload, a SchemaInfo (cached information from torchgen about schema),
    and the inputs/outputs to the OpOverload,
    this function checks to see if func is a view operator
    (by checking if any of the outputs in the op's schema
     are immutable aliases of inputs).
    If so, this function manually aliases the storage of the output tensor
    with its corresponding input tensor alias.
    It does this by unsafely overwriting the storage field of the output tensor
    to be the same storage as the input.
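
    For example, ``aten.view``'s schema is
    ``view(Tensor(a) self, SymInt[] size) -> Tensor(a)``: the output is an
    immutable alias ``(a)`` of ``self``, so for ``out = inp.view(-1)`` this
    function swaps ``out``'s storage to be ``inp``'s storage, without touching
    ``out``'s sizes/strides.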
    """
    assert isinstance(func, torch._ops.OpOverload)
    assert isinstance(args, tuple)
    assert isinstance(outs, (list, tuple))
    flat_outs = torch.utils._pytree.tree_leaves(outs)

    def alias_non_inplace_storage(arg, ret):
        # This is hopefully a reasonable assert:
        # subclasses that rely on this API for output aliasing
        # should always return wrapper tensor subclasses for us to manually alias.
        # In theory, if a subclass that needs this API wants to sometimes return
        # plain tensors, we could remove the assert and just not perform the aliasing,
        # but it seems safer to learn more about this case first.
        if is_traceable_wrapper_subclass(arg) or is_traceable_wrapper_subclass(ret):
            ret_list = ret if isinstance(ret, list) else [ret]
            for r in ret_list:
                assert type(arg) == type(
                    r
                ), f"""Called {str(func)} with input of type {type(arg)}
and output of type {type(ret)}. But expected types to match."""
        # Need to call a non-dispatcher helper, because we explicitly do **not**
        # want our subclass to intercept the set_() call.
        # Instead, our subclass should directly have its storage swapped out.
        # We **explicitly** don't want to reset the sizes on ret, if the storage implies a size change.
        # Why?
        # The purpose of this API is *not* to change the size/strides of our output - we assume it's already correct.
        # We just want to "fix up" the storage aliasing, without modifying our output's metadata.
        # Example: out = inp.expand(inp.shape[0], inp.shape[0])
        #     This requires swapping the storage of out to be the same as inp,
        #     but we do *not* want it to change the sizes/strides that were computed for out.

        if isinstance(ret, list):
            for r in ret:
                torch._functionalize_unsafe_set(r, arg)
        else:
            assert isinstance(ret, torch.Tensor), f"type: {type(ret)}"
            torch._functionalize_unsafe_set(ret, arg)

    def is_read_only_alias_match(arg, ret):
        shared_aliases = arg.alias_set & ret.alias_set
        return len(shared_aliases) > 0 and not arg.is_write

    num_args = len(func._schema.arguments)
    num_returns = len(func._schema.returns)
    for arg_idx in range(num_args):
        for return_idx in range(num_returns):
            if is_read_only_alias_match(
                schema_info.args[arg_idx], schema_info.outs[return_idx]
            ):
                alias_non_inplace_storage(args[arg_idx], outs[return_idx])


# This abstracts over the fact that in return_and_correct_aliasing,
# we sometimes use torchgen schema parsing (for aten ops, since torchscript's schema parsing is sometimes buggy),
# and sometimes use torchscript schema parsing (for custom ops, for which torchgen parsing is untested).
@dataclass
class AliasInfo:
    alias_set: Set[str]
    is_write: bool
    name: Optional[str]


@dataclass
class SchemaInfo:
    args: List[AliasInfo]
    outs: List[AliasInfo]


# Can't import torch._ops.OpOverload due to circular reference
parsed_schema_map: Dict[Any, SchemaInfo] = {}


# Given an OpOverload, returns schema information on it.
# This is cached for efficiency, since it can involve running torchgen
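# For illustration (actual ATen schema; the AliasInfo shown is what the parsing
# below is expected to produce, written out by hand here):
#
#     aten.add_.Tensor has schema
#         "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)"
#     so both the ``self`` argument and the return are reported as
#         AliasInfo(alias_set={"a"}, is_write=True, ...)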
def get_alias_info(func) -> SchemaInfo:
    if func in parsed_schema_map:
        return parsed_schema_map[func]
    # For ATen ops: use torchgen (since torchscript parser doesn't handle alias annotations
    # properly for some ops that output tensorlists)
    if func.namespace == "aten":
        torchgen_schema_str = str(func._schema)
        assert torchgen_schema_str.startswith("aten::")
        # remove the aten:: namespace, which the torchscript parser adds
        # and torchgen doesn't know how to handle
        torchgen_schema_str = torchgen_schema_str[6:]
        import re

        # the torchscript parser ends up converting int[2]=1 into int[2]=[1, 1],
        # which torchgen chokes on.
        torchgen_schema_str = re.sub(r"=\[[0, ]+\]", "=0", torchgen_schema_str)
        torchgen_schema_str = re.sub(r"=\[[1, ]+\]", "=1", torchgen_schema_str)
        # for aten::rot90
        torchgen_schema_str = torchgen_schema_str.replace("=[0, 1]", "=[0,1]")
        torchgen_schema = torchgen.model.FunctionSchema.parse(torchgen_schema_str)
        arg_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.annotation is None else set(a.annotation.alias_set)
                ),
                is_write=a.annotation is not None and a.annotation.is_write,
                name=a.name,
            )
            for a in torchgen_schema.arguments.flat_all
        ]
        out_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.annotation is None else set(a.annotation.alias_set)
                ),
                is_write=a.annotation is not None and a.annotation.is_write,
                name=a.name,
            )
            for a in torchgen_schema.returns
        ]
    else:
        # For non-aten ops, torchgen is untested so we rely on torchscript schema parsing
        arg_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.alias_info is None else set(a.alias_info.before_set)
                ),
                is_write=a.alias_info is not None and a.alias_info.is_write,
                name=a.name,
            )
            for a in func._schema.arguments
        ]
        out_schemas = [
            AliasInfo(
                alias_set=(
                    set() if a.alias_info is None else set(a.alias_info.before_set)
                ),
                is_write=a.alias_info is not None and a.alias_info.is_write,
                name=a.name,
            )
            for a in func._schema.returns
        ]
    schema_info = SchemaInfo(args=arg_schemas, outs=out_schemas)
    parsed_schema_map[func] = schema_info
    return schema_info


def return_and_correct_aliasing(func, args, kwargs, out):
    """
    This function should be used by wrapper tensor ``__torch_dispatch__`` subclasses
    that would like to work with torch.compile. It ensures that the subclass
    properly implements the aliasing behavior of every op,
    which is needed for correctness in AOTAutograd.
    This function will handle:

        * When we see a view op, we will alias the storages of any
          input and output tensor subclasses

        * When we see an inplace or out= op, we will directly
          return the corresponding input tensor, instead of returning
          a (potentially) fresh output tensor.
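
    Example usage inside a wrapper subclass (an illustrative sketch;
    ``unwrap``/``wrap`` are hypothetical helpers that convert between the
    wrapper and its inner tensors)::

        class WrapperTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                raw_out = func(*unwrap(args), **unwrap(kwargs))
                return return_and_correct_aliasing(func, args, kwargs, wrap(raw_out))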
    """

    # Caching here because torchgen parsing is definitely not fast, and this function is called
    # once for every op in the graph during functionalization.
    schema_info = get_alias_info(func)

    def get_write_alias(x):
        if len(x.alias_set) == 0:
            return None
        alias_set = list(x.alias_set)
        # torchscript allows for complicated alias sets, but our dispatcher ops only really involve simple aliasing
        assert len(alias_set) == 1
        if x.is_write:
            return alias_set[0]
        return None

    def get_arg_from_alias(output_alias, schema_info, args, kwargs):
        new_args, new_kwargs = torch.fx.operator_schemas.normalize_function(  # type: ignore[misc]
            func, args=args, kwargs=kwargs
        )

        arg_indices = [
            i for i, a in enumerate(schema_info.args) if output_alias in a.alias_set
        ]
        # For any dispatcher op with an output alias, we expect it to map to exactly one alias in the schema's input arguments.
        assert len(arg_indices) == 1
        idx = arg_indices[0]
        arg_info = schema_info.args[idx]
        if arg_info.name is not None and arg_info.name in new_kwargs:
            return new_kwargs[arg_info.name]
        return new_args[idx]

    # Fix up the storages of any outs so that they point to the same storage as the input,
    # if func is a view op.
    _correct_storage_aliasing(
        func, schema_info, args, (out,) if not isinstance(out, tuple) else out
    )

    # For inplace_view ops in particular, we'll try hard to make sure that the wrapper subclass's
    # metadata is set correctly.
    if torch.Tag.inplace_view in func.tags:
        # no_dispatch() to make sure that we secretly change the metadata on the wrapper,
        # but don't end up dispatching the op anywhere else.
        mutated_args = [
            x
            for i, x in enumerate(args)
            if get_write_alias(schema_info.args[i]) is not None
        ]
        # Assumption: we have a very small number of inplace_view ops that follow a strict schema:
        # there is only a single argument that gets its metadata mutated.
        assert len(mutated_args) == 1
        # This check exists because we generally *do* want to update the metadata of any wrapper subclasses,
        # but FunctionalTensor is special: it overrides all size/stride calls to plumb to the inner tensor,
        # so we don't actually need to update the metadata (and attempting to do so causes errors).
        from torch._subclasses.functional_tensor import FunctionalTensor

        if not isinstance(mutated_args[0], FunctionalTensor):
            with torch.utils._mode_utils.no_dispatch():
                # See Note: [Fake Tensor Dispatch Keys]
                # we're borrowing the way it modifies dispatch key TLS.
                meta_in_tls = torch._C._meta_in_tls_dispatch_include()
                torch._C._set_meta_in_tls_dispatch_include(True)
                try:
                    func(*args, **kwargs)
                finally:
                    torch._C._set_meta_in_tls_dispatch_include(meta_in_tls)

    # Next: we need to make sure to return inputs directly, if the output is a mutable alias (e.g. add_()).

    # simple case: none of our outputs have mutable aliases, so we can return the output as-is
    if not any(get_write_alias(r) is not None for r in schema_info.outs):
        return out

    # simplifying assumption: we don't have **any** ops with return types like "-> (Tensor(a!), Tensor)"
    if not all(get_write_alias(r) is not None for r in schema_info.outs):
        raise RuntimeError("Unsupported schema: " + str(func._schema))

    if len(func._schema.returns) == 1:
        return get_arg_from_alias(
            get_write_alias(schema_info.outs[0]), schema_info, args, kwargs
        )

    # In the multi-return case, all aten ops return a tuple / list, so cast accordingly.
    outs_to_return = type(out)(
        [
            (
                get_arg_from_alias(
                    get_write_alias(schema_info.outs[i]), schema_info, args, kwargs
                )
                if get_write_alias(r) is not None
                else o
            )
            for ((i, r), o) in zip(enumerate(schema_info.outs), out)
        ]
    )
    return outs_to_return