xref: /aosp_15_r20/external/pytorch/torch/_numpy/_util.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: ignore-errors
2
"""Assorted utilities, which do not need anything other than torch and stdlib.
"""
5
6import operator
7
8import torch
9
10from . import _dtypes_impl
11
12
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    """Return True when ``seq`` supports ``len()`` and is not a ``str``.

    Strings have a length but are deliberately excluded. Any object whose
    ``len()`` call raises an ``Exception`` is reported as a non-sequence.
    """
    is_string = isinstance(seq, str)
    if is_string:
        return False
    try:
        _ = len(seq)
    except Exception:
        has_length = False
    else:
        has_length = True
    return has_length
22
23
class AxisError(ValueError, IndexError):
    """Error raised for an axis argument that is out of range.

    Subclasses both ``ValueError`` and ``IndexError``, so handlers for either
    exception type will catch it.
    """
26
27
class UFuncTypeError(TypeError, RuntimeError):
    """Type error associated with ufuncs (presumably; it is not raised in this chunk).

    Subclasses both ``TypeError`` and ``RuntimeError``, so handlers for either
    exception type will catch it.
    """
30
31
def cast_if_needed(tensor, dtype):
    """Return ``tensor`` converted to ``dtype``, or ``tensor`` itself if no cast is needed.

    ``dtype=None`` means "keep the current dtype". When the dtype already
    matches, the very same object is returned (no copy).
    """
    needs_cast = dtype is not None and tensor.dtype != dtype
    return tensor.to(dtype) if needs_cast else tensor
37
38
def cast_int_to_float(x):
    """Promote boolean and integer tensors to the default float dtype.

    Dtypes whose ``_dtypes_impl._category`` is below 2 (bools and integers)
    are converted to ``_dtypes_impl.default_dtypes().float_dtype``. Every
    other tensor is returned unchanged.
    """
    if _dtypes_impl._category(x.dtype) >= 2:
        return x
    float_dtype = _dtypes_impl.default_dtypes().float_dtype
    return x.to(float_dtype)
44
45
# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    """Map a possibly negative axis index into the range ``[0, ndim)``.

    Parameters
    ----------
    ax : int
        The axis index; negative values count from the end.
    ndim : int
        The number of dimensions the axis refers to.
    argname : str, optional
        If given, prefixed to the error message (``"<argname>: axis ..."``) to
        identify the offending argument, as NumPy's ``msg_prefix`` does.

    Returns
    -------
    int
        ``ax`` if it is non-negative, otherwise ``ax + ndim``.

    Raises
    ------
    AxisError
        If ``ax`` is not in ``[-ndim, ndim)``.
    """
    if not (-ndim <= ax < ndim):
        msg = f"axis {ax} is out of bounds for array of dimension {ndim}"
        # Previously ``argname`` was accepted but ignored; use it as a prefix.
        if argname:
            msg = f"{argname}: {msg}"
        raise AxisError(msg)
    if ax < 0:
        ax += ndim
    return ax
53
54
# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalize an axis argument into a tuple of non-negative integer axes.

    A scalar shorthand such as ``1`` is promoted to ``(1,)``. Every entry is
    then passed through `normalize_axis_index`, which maps negative indices
    into ``[0, ndim)``. Unless ``allow_duplicate`` is set, the same axis may
    not appear more than once. Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` is normalized
        against.
    argname : str, optional
        Forwarded to `normalize_axis_index` and used in the repeated-axis
        error message.
    allow_duplicate : bool, optional
        If False (the default), reject an axis specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axes, each satisfying ``0 <= axis < ndim``.

    Raises
    ------
    AxisError
        From `normalize_axis_index`, if an axis is out of range.
    ValueError
        If an axis is repeated and ``allow_duplicate`` is False.
    """
    # Fast path for the common case of a single integer-like axis.
    if type(axis) in (tuple, list):
        axes = axis
    else:
        try:
            axes = [operator.index(axis)]
        except TypeError:
            # Not index-like: treat it as an iterable of axes.
            axes = axis
    # A list comprehension is faster here than handing tuple() a generator.
    normalized = tuple([normalize_axis_index(ax, ndim, argname) for ax in axes])
    if not allow_duplicate:
        unique_count = len({int(ax) for ax in normalized})
        if unique_count != len(normalized):
            if argname:
                raise ValueError(f"repeated axis in `{argname}` argument")
            raise ValueError("repeated axis")
    return normalized
99
100
def allow_only_single_axis(axis):
    """Unwrap a one-element axis sequence; ``None`` passes through unchanged.

    Raises NotImplementedError when ``axis`` has any length other than one.
    """
    if axis is None:
        return None
    if len(axis) == 1:
        return axis[0]
    raise NotImplementedError("does not handle tuple axis")
107
108
def expand_shape(arr_shape, axis):
    """Return ``arr_shape`` with size-1 dimensions inserted at ``axis``.

    Shape logic taken from numpy 1.23.x ``expand_dims``. ``axis`` may be a
    single value or a list/tuple. Positions refer to the *output* shape, whose
    rank is ``len(arr_shape) + len(axis)``, and may be negative.

    Returns the new shape as a list.
    """
    axes = axis if type(axis) in (list, tuple) else (axis,)
    out_ndim = len(axes) + len(arr_shape)
    inserted = normalize_axis_tuple(axes, out_ndim)
    source_dims = iter(arr_shape)
    # Each output slot is either a new size-1 axis or the next input dimension.
    return [1 if pos in inserted else next(source_dims) for pos in range(out_ndim)]
118
119
def apply_keepdims(tensor, axis, ndim):
    """Reinsert reduced dimensions as size-1 axes.

    With ``axis=None`` the tensor is treated as a scalar result. It is
    broadcast to shape ``(1,) * ndim`` and made contiguous. Otherwise size-1
    dimensions are inserted at ``axis`` (see ``expand_shape``) via
    ``reshape``.
    """
    if axis is not None:
        return tensor.reshape(expand_shape(tensor.shape, axis))
    # The tensor was a scalar: broadcast it up to ``ndim`` unit dimensions.
    unit_shape = (1,) * ndim
    return tensor.expand(unit_shape).contiguous()
129
130
def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None.

    Returns a ``(tensors, axis)`` pair. When ``axis`` is None, every tensor is
    flattened and the returned axis is 0. Otherwise the inputs and ``axis``
    are returned untouched.
    """
    if axis is not None:
        return tensors, axis
    flattened = tuple(t.flatten() for t in tensors)
    return flattened, 0
138
139
def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast a tensor to ``target_dtype``, subject to a casting rule.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast.
    target_dtype : torch dtype object
        The dtype to cast ``t`` to.
    casting : str
        The casting mode, see `np.can_cast`.

    Returns
    -------
    torch.Tensor
        ``t`` unchanged if it already has ``target_dtype``, otherwise
        ``t.to(target_dtype)`` (via ``cast_if_needed``).

    Raises
    ------
    TypeError
        If ``t.dtype`` cannot be cast to ``target_dtype`` under the
        ``casting`` rule.
    """
    can_cast = _dtypes_impl.can_cast_impl

    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)
170
171
def typecast_tensors(tensors, target_dtype, casting):
    """Apply ``typecast_tensor`` to every element of ``tensors``; return a tuple."""
    converted = [typecast_tensor(item, target_dtype, casting) for item in tensors]
    return tuple(converted)
174
175
176def _try_convert_to_tensor(obj):
177    try:
178        tensor = torch.as_tensor(obj)
179    except Exception as e:
180        mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
181        raise NotImplementedError(mesg)  # noqa: B904
182    return tensor
183
184
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce. A ``torch.Tensor`` is used as is; anything else is
        converted with ``torch.as_tensor`` (via ``_try_convert_to_tensor``).
    dtype : torch.dtype object or None
        Coerce to this torch dtype; ``None`` keeps the converted dtype.
    copy : bool
        If True, return a ``clone()`` of the (possibly cast/reshaped) tensor.
    ndmin : int
        The result has at least this many dimensions; missing dimensions are
        prepended as size-1 axes.

    Returns
    -------
    tensor : torch.Tensor
        A tensor object with the requested dtype, ndim and copy semantics.

    Notes
    -----
    This is almost a "tensor_like" coercion function. It does not handle
    wrapper ndarrays (those should be handled in the ndarray-aware layer prior
    to invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # Python floats are converted using torch's global default dtype
        # (typically float32). Elements not exactly representable there lose
        # precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        # So temporarily swap the global default for the dtype that
        # _dtypes_impl maps float32 to, and restore it even if conversion
        # fails. Note: this mutates process-global torch state for the
        # duration of the conversion.
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # prepend size-1 dimensions until the tensor has at least ``ndmin`` dims
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested
    if copy:
        tensor = tensor.clone()

    return tensor
239
240
def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)

    With a single argument:
    - an ``ndarray`` is replaced by its ``.tensor``;
    - a ``tuple`` is processed element-wise (recursively), returning a tuple;
    - anything else is returned as is.

    With several arguments, they are processed together as one tuple and a
    tuple is returned.

    Raises
    ------
    ValueError
        If called with no arguments. (Previously the exception object was
        returned instead of raised.)
    """
    # Function-level import, presumably to avoid a circular import with
    # ._ndarray — TODO confirm.
    from ._ndarray import ndarray

    if not inputs:
        raise ValueError("ndarrays_to_tensors requires at least one argument")
    if len(inputs) > 1:
        # Several positional arguments: treat them as a single tuple input.
        return ndarrays_to_tensors(inputs)

    input_ = inputs[0]
    if isinstance(input_, ndarray):
        return input_.tensor
    if isinstance(input_, tuple):
        return tuple(ndarrays_to_tensors(sub_input) for sub_input in input_)
    return input_
262