xref: /aosp_15_r20/external/pytorch/torch/testing/_internal/opinfo/refs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
3from torch.testing._internal.opinfo.core import (
4    BinaryUfuncInfo,
5    OpInfo,
6    ReductionOpInfo,
7    UnaryUfuncInfo,
8)
9
10
11# NOTE [Python References]
12# Python References emulate existing PyTorch operations, but can ultimately
13#   be expressed in terms of "primitive" operations from torch._prims.
14#
15# These references are experimental.
16# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577
17#   for additional context.
18#
# Python Reference OpInfos should be added to the python_ref_db list.
#   Tests can opt into running on these references by including
#   that list in the Sequence they pass to the @ops decorator.
22#
# When a Python Reference OpInfo is constructed, a pointer to an
#   existing OpInfo must be provided using the torch_opinfo_name kwarg.
#   The existing OpInfo with that name and the variant given by
#   torch_opinfo_variant_name (no variant by default) is found and
#   inherited from.
27#
28# Instead of just inheriting the existing OpInfo's metadata, the
29#   Python Reference OpInfos inherit the existing OpInfo's
30#   construction arguments. These arguments can be overridden
31#   by adding kwargs to the constructor.
32
33
34def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None):
35    """
36    Finds the OpInfo with the given name that has no variant name.
37    """
38    # NOTE: searching the global op_db doesn't work when OpInfos are split into
39    # different modules, as otherwise the op_db will not be fully constructed
40    # yet. So, instead the local op_db must be passed in explicitly.
41    if op_db is None:
42        from torch.testing._internal.common_methods_invocations import op_db
43
44    for opinfo in op_db:
45        if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name:
46            return opinfo
47
48
49def _inherit_constructor_args(name, op, inherited, overrides):
50    # inherits metadata
51    common_kwargs = {
52        "name": name,
53        "op": op,
54        "aliases": None,  # TODO add a check for alias coverage
55        "method_variant": None,
56        "inplace_variant": None,  # TODO: add a check for inplace coverage
57        "supports_scripting": False,
58    }
59
60    # Acquires inherited kwargs
61    kwargs = inherited.copy()
62
63    # Fixes metadata
64    if "kwargs" in kwargs:
65        kwargs.update(kwargs["kwargs"])
66        del kwargs["kwargs"]
67    if "self" in kwargs:
68        del kwargs["self"]
69    if "__class__" in kwargs:
70        del kwargs["__class__"]
71    if "skips" in kwargs:
72        del kwargs["skips"]
73    if "decorators" in kwargs:
74        del kwargs["decorators"]
75
76    # Overrides metadata
77    kwargs.update(common_kwargs)
78    kwargs.update(overrides)
79
80    # At the moment no prims support autograd, so we must not run autograd
81    # tests e.g. when testing dtype support.  Once we start writing autograd
82    # formulas for prims this can be removed.
83    kwargs["supports_autograd"] = False
84    kwargs["supports_gradgrad"] = False
85    kwargs["supports_fwgrad_bwgrad"] = False
86    kwargs["supports_inplace_autograd"] = False
87    kwargs["supports_forward_ad"] = False
88
89    return kwargs
90
91
class PythonRefInfo(OpInfo):
    """
    An OpInfo for a Python reference of an OpInfo base class operation.

    The referenced torch OpInfo is located by ``torch_opinfo_name`` and
    ``torch_opinfo_variant_name``. Its original constructor arguments are
    inherited, and any extra ``kwargs`` given here override them.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # The database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        validate_view_consistency=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.validate_view_consistency = validate_view_consistency
        # The lookup returns None when nothing matches. Name the requested
        # OpInfo so that a typo in torch_opinfo_name is easy to diagnose.
        assert isinstance(self.torch_opinfo, OpInfo), (
            f"{type(self).__name__}: no OpInfo named {torch_opinfo_name!r} "
            f"with variant {torch_opinfo_variant_name!r} was found "
            f"(got {type(self.torch_opinfo).__name__})"
        )

        inherited = self.torch_opinfo._original_opinfo_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)
        super().__init__(**ukwargs)
119
120
class ReductionPythonRefInfo(ReductionOpInfo):
    """
    An OpInfo for a Python reference of a reduction operation.

    The referenced torch ReductionOpInfo is located by ``torch_opinfo_name``
    and ``torch_opinfo_variant_name``. Its original reduction constructor
    arguments are inherited, and any extra ``kwargs`` given here override
    them. View-consistency validation is always disabled for these
    references.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # The database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        # The lookup returns None when nothing matches. Name the requested
        # OpInfo so that a typo (or a non-reduction target) is easy to diagnose.
        assert isinstance(self.torch_opinfo, ReductionOpInfo), (
            f"{type(self).__name__}: no ReductionOpInfo named "
            f"{torch_opinfo_name!r} with variant "
            f"{torch_opinfo_variant_name!r} was found "
            f"(got {type(self.torch_opinfo).__name__})"
        )

        inherited = self.torch_opinfo._original_reduction_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        # See https://github.com/pytorch/pytorch/issues/77216
        self.validate_view_consistency = False

        super().__init__(**ukwargs)
150
151
class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo):
    """
    An OpInfo for a Python reference of an elementwise unary operation.

    The referenced torch UnaryUfuncInfo is located by ``torch_opinfo_name``
    and ``torch_opinfo_variant_name``. Its original unary-ufunc constructor
    arguments are inherited, and any extra ``kwargs`` given here override
    them.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # The database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        validate_view_consistency=True,
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        self.validate_view_consistency = validate_view_consistency
        # The lookup returns None when nothing matches. Name the requested
        # OpInfo so that a typo (or a non-unary target) is easy to diagnose.
        assert isinstance(self.torch_opinfo, UnaryUfuncInfo), (
            f"{type(self).__name__}: no UnaryUfuncInfo named "
            f"{torch_opinfo_name!r} with variant "
            f"{torch_opinfo_variant_name!r} was found "
            f"(got {type(self.torch_opinfo).__name__})"
        )

        inherited = self.torch_opinfo._original_unary_ufunc_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)
180
181
class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo):
    """
    An OpInfo for a Python reference of an elementwise binary operation.

    The referenced torch BinaryUfuncInfo is located by ``torch_opinfo_name``
    and ``torch_opinfo_variant_name``. Its original binary-ufunc constructor
    arguments are inherited, and any extra ``kwargs`` given here override
    them.
    """

    def __init__(
        self,
        name,  # the string name of the callable Python reference
        *,
        op=None,  # the function variant of the operation, populated as torch.<name> if None
        op_db=None,  # The database of opinfos to search for the parent opinfo
        torch_opinfo_name,  # the string name of the corresponding torch opinfo
        torch_opinfo_variant_name="",  # the variant name for corresponding torch opinfo
        **kwargs,
    ):  # additional kwargs override kwargs inherited from the torch opinfo
        self.torch_opinfo_name = torch_opinfo_name
        self.torch_opinfo_variant_name = torch_opinfo_variant_name
        self.torch_opinfo = _find_referenced_opinfo(
            torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db
        )
        # The lookup returns None when nothing matches. Name the requested
        # OpInfo so that a typo (or a non-binary target) is easy to diagnose.
        assert isinstance(self.torch_opinfo, BinaryUfuncInfo), (
            f"{type(self).__name__}: no BinaryUfuncInfo named "
            f"{torch_opinfo_name!r} with variant "
            f"{torch_opinfo_variant_name!r} was found "
            f"(got {type(self.torch_opinfo).__name__})"
        )

        inherited = self.torch_opinfo._original_binary_ufunc_args
        ukwargs = _inherit_constructor_args(name, op, inherited, kwargs)

        super().__init__(**ukwargs)
208