# mypy: allow-untyped-defs
import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence

import torch
import torch._decomp
import torch._prims
import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims_common import torch_function_passthrough


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        torch.Tensor.sum_to_size: torch._refs.sum_to_size,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)

    # Support remapping torch.Tensor.foo to _refs.foo
    for s in dir(torch.Tensor):
        if s in torch._refs.__all__:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)

    # Support conversions
    for s in torch._refs._conversions.__all__:
        tensor_attr = getattr(torch.Tensor, s, None) or getattr(torch, s)
        r[tensor_attr] = torch._refs._conversions.__dict__.get(s)

    return r


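# A minimal lookup sketch (kept as a comment so nothing runs at import time);
# the first lookup restates the docstring example, and the second follows from
# the dir(torch.Tensor) remapping loop above:
#
#     refs_map = torch_to_refs_map()
#     refs_map[torch.add] is torch._refs.add          # torch.* function -> ref
#     refs_map[torch.Tensor.add] is torch._refs.add   # Tensor method -> same ref

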
@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims()
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


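# Illustrative only (comment): __torch_function__ below uses this set to let
# prim calls pass straight through the mode, e.g., per the docstring above:
#
#     torch._prims.add in all_prims()   # True

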
class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs.  (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager falls back to the original torch.* function
    if no ref exists; set strict=True to raise an error in that case instead.
    Even when a ref exists, it may still be desirable to fall back to torch.*
    for some calls; this behavior can be customized by passing a function as
    should_fallback_fn.
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as-is without interception,
        # unless we are in prims_mode, in which case we want to use nvprims.
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

        # For torch.ops.aten.*, use registered decompositions from torch._decomp.
        # torch._decomp.decomposition_table provides a mapping from
        # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
        # implementations.
        # There are other ways to implement this functionality; see
        # https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if func is None and isinstance(orig_func, torch._ops.OpOverload):
            func = torch._decomp.decomposition_table.get(orig_func, None)
        elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket):
            default = getattr(orig_func, "default", None)
            if default is not None:
                func = torch._decomp.decomposition_table.get(default, None)

        if func is not None:
            # If the ref exists, query whether we should use it or not
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)

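# A small, self-contained usage sketch. It is illustrative only (guarded so it
# never runs on import) and mirrors the class docstring above: inside the mode,
# torch.* calls are rerouted to their _refs counterparts, and should_fallback_fn
# can send selected calls back to the original torch.* implementation. The
# helper name fallback_to_eager_add is made up for this example.
if __name__ == "__main__":
    x = torch.randn(3)
    y = torch.randn(3)

    # Default mode: torch.add is interpreted as torch._refs.add.
    with TorchRefsMode():
        z = torch.add(x, y)
    assert torch.allclose(z, x + y)

    # Custom fallback: route torch.add back to the eager torch.add instead of
    # its ref, while every other call still goes through _refs.
    def fallback_to_eager_add(mode, orig_func, func, args, kwargs):
        return orig_func is torch.add

    with TorchRefsMode(should_fallback_fn=fallback_to_eager_add):
        torch.add(x, y)  # falls back to torch.add itself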