1# mypy: ignore-errors 2 3from torch.testing._internal.opinfo.core import ( 4 BinaryUfuncInfo, 5 OpInfo, 6 ReductionOpInfo, 7 UnaryUfuncInfo, 8) 9 10 11# NOTE [Python References] 12# Python References emulate existing PyTorch operations, but can ultimately 13# be expressed in terms of "primitive" operations from torch._prims. 14# 15# These references are experimental. 16# See https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-0/577 17# for additional context. 18# 19# Python Reference OpInfos should be added to the python_ref_db list below. 20# Tests can opt-into running on these references by including 21# that list in the Sequence they pass to the @ops decorator. 22# 23# When a Python Reference OpInfo is constructed a pointer to an 24# existing OpInfo must be provided using the torch_opinfo_name kwarg. 25# The existing OpInfo with that name and no variant will be found 26# to inherit from. 27# 28# Instead of just inheriting the existing OpInfo's metadata, the 29# Python Reference OpInfos inherit the existing OpInfo's 30# construction arguments. These arguments can be overridden 31# by adding kwargs to the constructor. 32 33 34def _find_referenced_opinfo(referenced_name, variant_name, *, op_db=None): 35 """ 36 Finds the OpInfo with the given name that has no variant name. 37 """ 38 # NOTE: searching the global op_db doesn't work when OpInfos are split into 39 # different modules, as otherwise the op_db will not be fully constructed 40 # yet. So, instead the local op_db must be passed in explicitly. 41 if op_db is None: 42 from torch.testing._internal.common_methods_invocations import op_db 43 44 for opinfo in op_db: 45 if opinfo.name == referenced_name and opinfo.variant_test_name == variant_name: 46 return opinfo 47 48 49def _inherit_constructor_args(name, op, inherited, overrides): 50 # inherits metadata 51 common_kwargs = { 52 "name": name, 53 "op": op, 54 "aliases": None, # TODO add a check for alias coverage 55 "method_variant": None, 56 "inplace_variant": None, # TODO: add a check for inplace coverage 57 "supports_scripting": False, 58 } 59 60 # Acquires inherited kwargs 61 kwargs = inherited.copy() 62 63 # Fixes metadata 64 if "kwargs" in kwargs: 65 kwargs.update(kwargs["kwargs"]) 66 del kwargs["kwargs"] 67 if "self" in kwargs: 68 del kwargs["self"] 69 if "__class__" in kwargs: 70 del kwargs["__class__"] 71 if "skips" in kwargs: 72 del kwargs["skips"] 73 if "decorators" in kwargs: 74 del kwargs["decorators"] 75 76 # Overrides metadata 77 kwargs.update(common_kwargs) 78 kwargs.update(overrides) 79 80 # At the moment no prims support autograd, so we must not run autograd 81 # tests e.g. when testing dtype support. Once we start writing autograd 82 # formulas for prims this can be removed. 83 kwargs["supports_autograd"] = False 84 kwargs["supports_gradgrad"] = False 85 kwargs["supports_fwgrad_bwgrad"] = False 86 kwargs["supports_inplace_autograd"] = False 87 kwargs["supports_forward_ad"] = False 88 89 return kwargs 90 91 92class PythonRefInfo(OpInfo): 93 """ 94 An OpInfo for a Python reference of an OpInfo base class operation. 95 """ 96 97 def __init__( 98 self, 99 name, # the stringname of the callable Python reference 100 *, 101 op=None, # the function variant of the operation, populated as torch.<name> if None 102 op_db=None, # The database of opinfos to search for the parent opinfo 103 torch_opinfo_name, # the string name of the corresponding torch opinfo 104 torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo 105 validate_view_consistency=True, 106 **kwargs, 107 ): # additional kwargs override kwargs inherited from the torch opinfo 108 self.torch_opinfo_name = torch_opinfo_name 109 self.torch_opinfo_variant_name = torch_opinfo_variant_name 110 self.torch_opinfo = _find_referenced_opinfo( 111 torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db 112 ) 113 self.validate_view_consistency = validate_view_consistency 114 assert isinstance(self.torch_opinfo, OpInfo) 115 116 inherited = self.torch_opinfo._original_opinfo_args 117 ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) 118 super().__init__(**ukwargs) 119 120 121class ReductionPythonRefInfo(ReductionOpInfo): 122 """ 123 An OpInfo for a Python reference of an elementwise unary operation. 124 """ 125 126 def __init__( 127 self, 128 name, # the stringname of the callable Python reference 129 *, 130 op=None, # the function variant of the operation, populated as torch.<name> if None 131 op_db=None, # The database of opinfos to search for the parent opinfo 132 torch_opinfo_name, # the string name of the corresponding torch opinfo 133 torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo 134 **kwargs, 135 ): # additional kwargs override kwargs inherited from the torch opinfo 136 self.torch_opinfo_name = torch_opinfo_name 137 self.torch_opinfo_variant_name = torch_opinfo_variant_name 138 self.torch_opinfo = _find_referenced_opinfo( 139 torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db 140 ) 141 assert isinstance(self.torch_opinfo, ReductionOpInfo) 142 143 inherited = self.torch_opinfo._original_reduction_args 144 ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) 145 146 # See https://github.com/pytorch/pytorch/issues/77216 147 self.validate_view_consistency = False 148 149 super().__init__(**ukwargs) 150 151 152class ElementwiseUnaryPythonRefInfo(UnaryUfuncInfo): 153 """ 154 An OpInfo for a Python reference of an elementwise unary operation. 155 """ 156 157 def __init__( 158 self, 159 name, # the stringname of the callable Python reference 160 *, 161 op=None, # the function variant of the operation, populated as torch.<name> if None 162 op_db=None, # The database of opinfos to search for the parent opinfo 163 torch_opinfo_name, # the string name of the corresponding torch opinfo 164 torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo 165 validate_view_consistency=True, 166 **kwargs, 167 ): # additional kwargs override kwargs inherited from the torch opinfo 168 self.torch_opinfo_name = torch_opinfo_name 169 self.torch_opinfo_variant_name = torch_opinfo_variant_name 170 self.torch_opinfo = _find_referenced_opinfo( 171 torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db 172 ) 173 self.validate_view_consistency = validate_view_consistency 174 assert isinstance(self.torch_opinfo, UnaryUfuncInfo) 175 176 inherited = self.torch_opinfo._original_unary_ufunc_args 177 ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) 178 179 super().__init__(**ukwargs) 180 181 182class ElementwiseBinaryPythonRefInfo(BinaryUfuncInfo): 183 """ 184 An OpInfo for a Python reference of an elementwise binary operation. 185 """ 186 187 def __init__( 188 self, 189 name, # the stringname of the callable Python reference 190 *, 191 op=None, # the function variant of the operation, populated as torch.<name> if None 192 op_db=None, # The database of opinfos to search for the parent opinfo 193 torch_opinfo_name, # the string name of the corresponding torch opinfo 194 torch_opinfo_variant_name="", # the variant name for corresponding torch opinfo 195 **kwargs, 196 ): # additional kwargs override kwargs inherited from the torch opinfo 197 self.torch_opinfo_name = torch_opinfo_name 198 self.torch_opinfo_variant_name = torch_opinfo_variant_name 199 self.torch_opinfo = _find_referenced_opinfo( 200 torch_opinfo_name, torch_opinfo_variant_name, op_db=op_db 201 ) 202 assert isinstance(self.torch_opinfo, BinaryUfuncInfo) 203 204 inherited = self.torch_opinfo._original_binary_ufunc_args 205 ukwargs = _inherit_constructor_args(name, op, inherited, kwargs) 206 207 super().__init__(**ukwargs) 208