1# mypy: ignore-errors 2 3import collections 4import warnings 5from functools import partial, wraps 6from typing import Sequence 7 8import numpy as np 9 10import torch 11from torch.testing._internal.common_cuda import TEST_CUDA 12from torch.testing._internal.common_dtype import ( 13 _dispatch_dtypes, 14 all_types, 15 all_types_and, 16 all_types_and_complex, 17 all_types_and_complex_and, 18 all_types_and_half, 19 complex_types, 20 floating_and_complex_types, 21 floating_and_complex_types_and, 22 floating_types, 23 floating_types_and, 24 floating_types_and_half, 25 integral_types, 26 integral_types_and, 27) 28from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict 29 30 31COMPLETE_DTYPES_DISPATCH = ( 32 all_types, 33 all_types_and_complex, 34 all_types_and_half, 35 floating_types, 36 floating_and_complex_types, 37 floating_types_and_half, 38 integral_types, 39 complex_types, 40) 41 42EXTENSIBLE_DTYPE_DISPATCH = ( 43 all_types_and_complex_and, 44 floating_types_and, 45 floating_and_complex_types_and, 46 integral_types_and, 47 all_types_and, 48) 49 50# Better way to acquire devices? 51DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else []) 52 53 54class _dynamic_dispatch_dtypes(_dispatch_dtypes): 55 # Class to tag the dynamically generated types. 56 pass 57 58 59def get_supported_dtypes(op, sample_inputs_fn, device_type): 60 # Returns the supported dtypes for the given operator and device_type pair. 61 assert device_type in ["cpu", "cuda"] 62 if not TEST_CUDA and device_type == "cuda": 63 warnings.warn( 64 "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!" 65 ) 66 return _dynamic_dispatch_dtypes(()) 67 68 supported_dtypes = set() 69 for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half): 70 try: 71 samples = sample_inputs_fn(op, device_type, dtype, False) 72 except RuntimeError: 73 # If `sample_inputs_fn` doesn't support sampling for a given 74 # `dtype`, we assume that the `dtype` is not supported. 75 # We raise a warning, so that user knows that this was the case 76 # and can investigate if there was an issue with the `sample_inputs_fn`. 77 warnings.warn( 78 f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}" 79 ) 80 continue 81 82 # We assume the dtype is supported 83 # only if all samples pass for the given dtype. 84 supported = True 85 for sample in samples: 86 try: 87 op(sample.input, *sample.args, **sample.kwargs) 88 except RuntimeError as re: 89 # dtype is not supported 90 supported = False 91 break 92 93 if supported: 94 supported_dtypes.add(dtype) 95 96 return _dynamic_dispatch_dtypes(supported_dtypes) 97 98 99def dtypes_dispatch_hint(dtypes): 100 # Function returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH) 101 # and its string representation for the passed `dtypes`. 102 return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str") 103 104 # CUDA is not available, dtypes will be empty. 105 if len(dtypes) == 0: 106 return return_type((), "()") 107 108 set_dtypes = set(dtypes) 109 for dispatch in COMPLETE_DTYPES_DISPATCH: 110 # Short circuit if we get an exact match. 111 if set(dispatch()) == set_dtypes: 112 return return_type(dispatch, dispatch.__name__ + "()") 113 114 chosen_dispatch = None 115 chosen_dispatch_score = 0.0 116 for dispatch in EXTENSIBLE_DTYPE_DISPATCH: 117 dispatch_dtypes = set(dispatch()) 118 if not dispatch_dtypes.issubset(set_dtypes): 119 continue 120 121 score = len(dispatch_dtypes) 122 if score > chosen_dispatch_score: 123 chosen_dispatch_score = score 124 chosen_dispatch = dispatch 125 126 # If user passed dtypes which are lower than the lowest 127 # dispatch type available (not likely but possible in code path). 128 if chosen_dispatch is None: 129 return return_type((), str(dtypes)) 130 131 return return_type( 132 partial(dispatch, *tuple(set(dtypes) - set(dispatch()))), 133 dispatch.__name__ + str(tuple(set(dtypes) - set(dispatch()))), 134 ) 135 136 137def is_dynamic_dtype_set(op): 138 # Detect if the OpInfo entry acquired dtypes dynamically 139 # using `get_supported_dtypes`. 140 return op.dynamic_dtypes 141 142 143def str_format_dynamic_dtype(op): 144 fmt_str = f""" 145 OpInfo({op.name}, 146 dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str}, 147 dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str}, 148 ) 149 """ 150 151 return fmt_str 152 153 154def np_unary_ufunc_integer_promotion_wrapper(fn): 155 # Wrapper that passes PyTorch's default scalar 156 # type as an argument to the wrapped NumPy 157 # unary ufunc when given an integer input. 158 # This mimicks PyTorch's integer->floating point 159 # type promotion. 160 # 161 # This is necessary when NumPy promotes 162 # integer types to double, since PyTorch promotes 163 # integer types to the default scalar type. 164 165 # Helper to determine if promotion is needed 166 def is_integral(dtype): 167 return dtype in [ 168 np.bool_, 169 bool, 170 np.uint8, 171 np.int8, 172 np.int16, 173 np.int32, 174 np.int64, 175 ] 176 177 @wraps(fn) 178 def wrapped_fn(x): 179 # As the default dtype can change, acquire it when function is called. 180 # NOTE: Promotion in PyTorch is from integer types to the default dtype 181 np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()] 182 183 if is_integral(x.dtype): 184 return fn(x.astype(np_dtype)) 185 return fn(x) 186 187 return wrapped_fn 188 189 190def reference_reduction_numpy(f, supports_keepdims=True): 191 """Wraps a NumPy reduction operator. 192 193 The wrapper function will forward dim, keepdim, mask, and identity 194 kwargs to the wrapped function as the NumPy equivalent axis, 195 keepdims, where, and initiak kwargs, respectively. 196 197 Args: 198 f: NumPy reduction operator to wrap 199 supports_keepdims (bool, optional): Whether the NumPy operator accepts 200 keepdims parameter. If it does not, the wrapper will manually unsqueeze 201 the reduced dimensions if it was called with keepdim=True. Defaults to True. 202 203 Returns: 204 Wrapped function 205 206 """ 207 208 @wraps(f) 209 def wrapper(x: np.ndarray, *args, **kwargs): 210 # Copy keys into a set 211 keys = set(kwargs.keys()) 212 213 dim = kwargs.pop("dim", None) 214 keepdim = kwargs.pop("keepdim", False) 215 216 if "dim" in keys: 217 dim = tuple(dim) if isinstance(dim, Sequence) else dim 218 219 # NumPy reductions don't accept dim=0 for scalar inputs 220 # so we convert it to None if and only if dim is equivalent 221 if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}: 222 kwargs["axis"] = None 223 else: 224 kwargs["axis"] = dim 225 226 if "keepdim" in keys and supports_keepdims: 227 kwargs["keepdims"] = keepdim 228 229 if "mask" in keys: 230 mask = kwargs.pop("mask") 231 if mask is not None: 232 assert mask.layout == torch.strided 233 kwargs["where"] = mask.cpu().numpy() 234 235 if "identity" in keys: 236 identity = kwargs.pop("identity") 237 if identity is not None: 238 if identity.dtype is torch.bfloat16: 239 identity = identity.cpu().to(torch.float32) 240 else: 241 identity = identity.cpu() 242 kwargs["initial"] = identity.numpy() 243 244 result = f(x, *args, **kwargs) 245 246 # Unsqueeze reduced dimensions if NumPy does not support keepdims 247 if keepdim and not supports_keepdims and x.ndim > 0: 248 dim = list(range(x.ndim)) if dim is None else dim 249 result = np.expand_dims(result, dim) 250 251 return result 252 253 return wrapper 254 255 256def prod_numpy(a, *args, **kwargs): 257 """ 258 The function will call np.prod with type as np.int64 if the input type 259 is int or uint64 if is uint. This is necessary because windows np.prod uses by default 260 int32 while on linux it uses int64. 261 This is for fixing integer overflow https://github.com/pytorch/pytorch/issues/77320 262 263 Returns: 264 np.prod of input 265 """ 266 if "dtype" not in kwargs: 267 if np.issubdtype(a.dtype, np.signedinteger): 268 a = a.astype(np.int64) 269 elif np.issubdtype(a.dtype, np.unsignedinteger): 270 a = a.astype(np.uint64) 271 272 fn = reference_reduction_numpy(np.prod) 273 return fn(a, *args, **kwargs) 274