# mypy: ignore-errors

import collections
import warnings
from functools import partial, wraps
from typing import Sequence

import numpy as np

import torch
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_dtype import (
    _dispatch_dtypes,
    all_types,
    all_types_and,
    all_types_and_complex,
    all_types_and_complex_and,
    all_types_and_half,
    complex_types,
    floating_and_complex_types,
    floating_and_complex_types_and,
    floating_types,
    floating_types_and,
    floating_types_and_half,
    integral_types,
    integral_types_and,
)
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict


COMPLETE_DTYPES_DISPATCH = (
    all_types,
    all_types_and_complex,
    all_types_and_half,
    floating_types,
    floating_and_complex_types,
    floating_types_and_half,
    integral_types,
    complex_types,
)

EXTENSIBLE_DTYPE_DISPATCH = (
    all_types_and_complex_and,
    floating_types_and,
    floating_and_complex_types_and,
    integral_types_and,
    all_types_and,
)
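

# Illustrative sketch (a hypothetical helper added for clarity, not part of the
# original module): the two tuples above differ in arity. "Complete" dispatches
# are nullary and always return the same fixed dtype set, while "extensible"
# dispatches accept extra dtypes to append.
def _example_dispatch_families():
    # floating_types() -> (torch.float32, torch.float64)
    complete = floating_types()
    # floating_types_and(torch.half) appends the extra dtype:
    # (torch.float32, torch.float64, torch.float16)
    extensible = floating_types_and(torch.half)
    return complete, extensible
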

# Better way to acquire devices?
DEVICES = ["cpu"] + (["cuda"] if TEST_CUDA else [])


class _dynamic_dispatch_dtypes(_dispatch_dtypes):
    # Class to tag the dynamically generated types.
    pass


def get_supported_dtypes(op, sample_inputs_fn, device_type):
    # Returns the supported dtypes for the given operator and device_type pair.
    assert device_type in ["cpu", "cuda"]
    if not TEST_CUDA and device_type == "cuda":
        warnings.warn(
            "WARNING: CUDA is not available, empty_dtypes dispatch will be returned!"
        )
        return _dynamic_dispatch_dtypes(())

    supported_dtypes = set()
    for dtype in all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half):
        try:
            samples = sample_inputs_fn(op, device_type, dtype, False)
        except RuntimeError:
            # If `sample_inputs_fn` doesn't support sampling for a given
            # `dtype`, we assume that the `dtype` is not supported.
            # We raise a warning so that the user knows this was the case
            # and can investigate whether there was an issue with the `sample_inputs_fn`.
            warnings.warn(
                f"WARNING: Unable to generate sample for device:{device_type} and dtype:{dtype}"
            )
            continue

        # We assume the dtype is supported
        # only if all samples pass for the given dtype.
        supported = True
        for sample in samples:
            try:
                op(sample.input, *sample.args, **sample.kwargs)
            except RuntimeError:
                # dtype is not supported
                supported = False
                break

        if supported:
            supported_dtypes.add(dtype)

    return _dynamic_dispatch_dtypes(supported_dtypes)
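

# Hedged usage sketch (`_sample_inputs_sin` is a hypothetical sample function
# added for illustration, not part of the OpInfo database). It shows how
# `get_supported_dtypes` probes an operator: every dtype whose samples run
# without a RuntimeError is reported as supported.
def _example_get_supported_dtypes():
    # Imported lazily to avoid a circular import at module load time.
    from torch.testing._internal.opinfo.core import SampleInput

    def _sample_inputs_sin(op, device, dtype, requires_grad):
        # A single minimal sample; real sample_inputs functions yield many.
        return [SampleInput(torch.ones(3, device=device, dtype=dtype))]

    return get_supported_dtypes(torch.sin, _sample_inputs_sin, "cpu")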


def dtypes_dispatch_hint(dtypes):
    # Returns the appropriate dispatch function (from COMPLETE_DTYPES_DISPATCH and EXTENSIBLE_DTYPE_DISPATCH)
    # and its string representation for the passed `dtypes`.
    return_type = collections.namedtuple("return_type", "dispatch_fn dispatch_fn_str")

    # If CUDA is not available, dtypes will be empty.
    if len(dtypes) == 0:
        return return_type((), "()")

    set_dtypes = set(dtypes)
    for dispatch in COMPLETE_DTYPES_DISPATCH:
        # Short circuit if we get an exact match.
        if set(dispatch()) == set_dtypes:
            return return_type(dispatch, dispatch.__name__ + "()")

    chosen_dispatch = None
    chosen_dispatch_score = 0.0
    for dispatch in EXTENSIBLE_DTYPE_DISPATCH:
        dispatch_dtypes = set(dispatch())
        if not dispatch_dtypes.issubset(set_dtypes):
            continue

        score = len(dispatch_dtypes)
        if score > chosen_dispatch_score:
            chosen_dispatch_score = score
            chosen_dispatch = dispatch

    # If the user passed dtypes that don't contain even the smallest
    # extensible dispatch set (unlikely, but possible in this code path).
    if chosen_dispatch is None:
        return return_type((), str(dtypes))

    return return_type(
        partial(chosen_dispatch, *tuple(set(dtypes) - set(chosen_dispatch()))),
        chosen_dispatch.__name__ + str(tuple(set(dtypes) - set(chosen_dispatch()))),
    )
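

# Hedged usage sketch (hypothetical helper added for illustration). An exact
# match against a complete dispatch is returned directly; a superset of an
# extensible dispatch is expressed as that dispatch plus the extra dtypes.
def _example_dtypes_dispatch_hint():
    # Exact match: dispatch_fn_str reads "floating_types()".
    exact = dtypes_dispatch_hint(floating_types())
    # Superset of floating_types(): dispatch_fn_str reads roughly
    # "floating_types_and(torch.bfloat16,)".
    extended = dtypes_dispatch_hint(floating_types_and(torch.bfloat16))
    return exact, extended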


def is_dynamic_dtype_set(op):
    # Detect if the OpInfo entry acquired dtypes dynamically
    # using `get_supported_dtypes`.
    return op.dynamic_dtypes


def str_format_dynamic_dtype(op):
    fmt_str = f"""
        OpInfo({op.name},
               dtypes={dtypes_dispatch_hint(op.dtypes).dispatch_fn_str},
               dtypesIfCUDA={dtypes_dispatch_hint(op.dtypesIfCUDA).dispatch_fn_str},
        )
        """

    return fmt_str


def np_unary_ufunc_integer_promotion_wrapper(fn):
    # Wrapper that passes PyTorch's default scalar
    #   type as an argument to the wrapped NumPy
    #   unary ufunc when given an integer input.
    #   This mimics PyTorch's integer->floating point
    #   type promotion.
    #
    # This is necessary when NumPy promotes
    #   integer types to double, since PyTorch promotes
    #   integer types to the default scalar type.

    # Helper to determine if promotion is needed
    def is_integral(dtype):
        return dtype in [
            np.bool_,
            bool,
            np.uint8,
            np.int8,
            np.int16,
            np.int32,
            np.int64,
        ]

    @wraps(fn)
    def wrapped_fn(x):
        # As the default dtype can change, acquire it when the function is called.
        # NOTE: Promotion in PyTorch is from integer types to the default dtype.
        np_dtype = torch_to_numpy_dtype_dict[torch.get_default_dtype()]

        if is_integral(x.dtype):
            return fn(x.astype(np_dtype))
        return fn(x)

    return wrapped_fn
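

# Hedged usage sketch (hypothetical helper added for illustration): wrapping
# np.sin so that integer inputs are first cast to PyTorch's default scalar
# type, matching torch.sin's integer -> floating-point promotion.
def _example_integer_promotion_wrapper():
    np_sin = np_unary_ufunc_integer_promotion_wrapper(np.sin)
    ints = np.arange(4, dtype=np.int64)
    # With the default dtype of torch.float32 this yields a float32 result,
    # whereas plain np.sin(ints) would return float64.
    return np_sin(ints)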


def reference_reduction_numpy(f, supports_keepdims=True):
    """Wraps a NumPy reduction operator.

    The wrapper function will forward the dim, keepdim, mask, and identity
    kwargs to the wrapped function as the NumPy-equivalent axis,
    keepdims, where, and initial kwargs, respectively.

    Args:
        f: NumPy reduction operator to wrap
        supports_keepdims (bool, optional): Whether the NumPy operator accepts
            the keepdims parameter. If it does not, the wrapper will manually
            unsqueeze the reduced dimensions when called with keepdim=True.
            Defaults to True.

    Returns:
        Wrapped function

    """

    @wraps(f)
    def wrapper(x: np.ndarray, *args, **kwargs):
        # Copy keys into a set
        keys = set(kwargs.keys())

        dim = kwargs.pop("dim", None)
        keepdim = kwargs.pop("keepdim", False)

        if "dim" in keys:
            dim = tuple(dim) if isinstance(dim, Sequence) else dim

            # NumPy reductions don't accept dim=0 for scalar inputs,
            # so we convert it to None if and only if dim is equivalent
            # to a full reduction of the 0-d input.
            if x.ndim == 0 and dim in {0, -1, (0,), (-1,)}:
                kwargs["axis"] = None
            else:
                kwargs["axis"] = dim

        if "keepdim" in keys and supports_keepdims:
            kwargs["keepdims"] = keepdim

        if "mask" in keys:
            mask = kwargs.pop("mask")
            if mask is not None:
                assert mask.layout == torch.strided
                kwargs["where"] = mask.cpu().numpy()

        if "identity" in keys:
            identity = kwargs.pop("identity")
            if identity is not None:
                if identity.dtype is torch.bfloat16:
                    identity = identity.cpu().to(torch.float32)
                else:
                    identity = identity.cpu()
                kwargs["initial"] = identity.numpy()

        result = f(x, *args, **kwargs)

        # Unsqueeze reduced dimensions if NumPy does not support keepdims
        if keepdim and not supports_keepdims and x.ndim > 0:
            dim = list(range(x.ndim)) if dim is None else dim
            result = np.expand_dims(result, dim)

        return result

    return wrapper
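

# Hedged usage sketch (hypothetical helper added for illustration): a
# torch-style call signature (dim/keepdim) is translated to NumPy's
# axis/keepdims by the wrapper.
def _example_reference_reduction():
    ref_sum = reference_reduction_numpy(np.sum)
    x = np.ones((2, 3), dtype=np.float32)
    # Equivalent to np.sum(x, axis=1, keepdims=True); result has shape (2, 1).
    return ref_sum(x, dim=1, keepdim=True)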


def prod_numpy(a, *args, **kwargs):
    """
    Calls np.prod after casting the input to np.int64 if its dtype is a signed
    integer, or to np.uint64 if it is unsigned. This is necessary because on
    Windows np.prod accumulates in int32 by default, while on Linux it uses int64.
    This fixes the integer overflow reported in
    https://github.com/pytorch/pytorch/issues/77320.

    Returns:
        np.prod of the input
    """
    if "dtype" not in kwargs:
        if np.issubdtype(a.dtype, np.signedinteger):
            a = a.astype(np.int64)
        elif np.issubdtype(a.dtype, np.unsignedinteger):
            a = a.astype(np.uint64)

    fn = reference_reduction_numpy(np.prod)
    return fn(a, *args, **kwargs)
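

# Hedged usage sketch (hypothetical helper added for illustration): integer
# inputs are widened to 64 bits so the product does not overflow the 32-bit
# accumulator NumPy defaults to on Windows.
def _example_prod_numpy():
    x = np.array([2**20, 2**20], dtype=np.int32)
    # prod_numpy upcasts to int64 first, so this returns 2**40 exactly,
    # whereas np.prod(x) could overflow where int32 is the default accumulator.
    return prod_numpy(x)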
274