# mypy: ignore-errors

"""Dtypes/scalar type implementations with torch dtypes.

Here `dtype` is always a torch.dtype; this module knows nothing about
scalar types, wrapper dtypes or anything like that. PyTorch only.
"""
from collections import namedtuple

import torch


# defaults: mimic NumPy, allow user control
DefaultDTypes = namedtuple(
    "DefaultDTypes", ["float_dtype", "complex_dtype", "int_dtype"]
)

# a global state
# We set it the first time we call default_dtypes() to avoid importing
# torch._dynamo.config at module load time and creating a circular import.
_default_dtypes = None


def default_dtypes():
    global _default_dtypes
    if _default_dtypes is None:
        import torch._dynamo.config as config

        _default_dtypes = DefaultDTypes(
            float_dtype=getattr(torch, config.numpy_default_float),
            complex_dtype=getattr(torch, config.numpy_default_complex),
            int_dtype=getattr(torch, config.numpy_default_int),
        )
        assert isinstance(_default_dtypes.float_dtype, torch.dtype)
        assert isinstance(_default_dtypes.complex_dtype, torch.dtype)
        assert isinstance(_default_dtypes.int_dtype, torch.dtype)
    return _default_dtypes
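
# Illustrative usage (a sketch; the concrete dtypes come from the
# torch._dynamo.config.numpy_default_* settings, which mimic NumPy by default):
#
#     >>> dt = default_dtypes()
#     >>> dt.float_dtype, dt.int_dtype
#     (torch.float64, torch.int64)      # typical NumPy-like defaults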


def get_default_dtype_for(dtype):
    """Default dtype for the given dtype's category (bool/int/float/complex)."""
    if dtype == torch.bool:
        return dtype
    if dtype.is_complex:
        return default_dtypes().complex_dtype
    if dtype.is_floating_point:
        return default_dtypes().float_dtype
    # else, it must be (some) integer
    return default_dtypes().int_dtype
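
# For example (a sketch, assuming the usual NumPy-like defaults):
#
#     >>> get_default_dtype_for(torch.float16)   # any float -> default float
#     torch.float64
#     >>> get_default_dtype_for(torch.uint8)     # any integer -> default int
#     torch.int64
#     >>> get_default_dtype_for(torch.bool)      # bool is returned unchanged
#     torch.bool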


from . import _casting_dicts as _cd


def can_cast_impl(from_torch_dtype, to_torch_dtype, casting):
    return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype]
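
# The casting tables are precomputed in _casting_dicts and mirror NumPy's
# casting rules; a sketch of the expected behaviour:
#
#     >>> can_cast_impl(torch.float32, torch.float64, "safe")
#     True
#     >>> can_cast_impl(torch.float64, torch.float32, "safe")
#     False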


def result_type_impl(*tensors):
    # NB: torch dtypes here
    dtyp = tensors[0].dtype
    if len(tensors) == 1:
        return dtyp

    for curr in tensors[1:]:
        dtyp = _cd._result_type_dict[dtyp][curr.dtype]

    return dtyp
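
# Promotion is applied pairwise, left to right; e.g. (a sketch, mirroring
# NumPy's result_type):
#
#     >>> result_type_impl(torch.ones(2, dtype=torch.int8),
#     ...                  torch.ones(2, dtype=torch.float32))
#     torch.float32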


def python_type_for_torch(dtyp):
    """Get the Python scalar type corresponding to a torch dtype."""
    if dtyp.is_floating_point:
        typ = float
    elif dtyp.is_complex:
        typ = complex
    elif dtyp == torch.bool:
        typ = bool
    else:
        typ = int
    return typ
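
# For instance:
#
#     >>> python_type_for_torch(torch.float32) is float
#     True
#     >>> python_type_for_torch(torch.int64) is int
#     True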


# ### NEP 50 helpers ###

_SCALAR_TYPES = (int, bool, float, complex)

_SCALAR_AND_SYMBOLIC_TYPES = (
    *_SCALAR_TYPES,
    torch.SymInt,
    torch.SymFloat,
    torch.SymBool,
)

_NEP50_FUNCS_TENSOR_ONLY = (
    "minimum",
    "maximum",
    "logaddexp",
    "logaddexp2",
    "lcm",
    "gcd",
    "hypot",
    "heaviside",
    "fmod",
    "fmin",
    "fmax",
    "copysign",
    "arctan2",
)


def is_scalar(x):
    return isinstance(x, _SCALAR_TYPES)


def is_scalar_or_symbolic(x):
    return isinstance(x, _SCALAR_AND_SYMBOLIC_TYPES)


def _dtype_for_scalar(py_type):
    return {
        bool: torch.bool,
        torch.SymBool: torch.bool,
        int: torch.int64,
        torch.SymInt: torch.int64,
        float: torch.float64,
        torch.SymFloat: torch.float64,
        complex: torch.complex128,
    }[py_type]
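
# A weak Python scalar (or SymInt/SymFloat/SymBool) is given a concrete torch
# dtype here, e.g.:
#
#     >>> _dtype_for_scalar(float)
#     torch.float64
#     >>> _dtype_for_scalar(bool)
#     torch.bool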


def _dtype_for_scalar_or_tensor(x):
    return x.dtype if isinstance(x, torch.Tensor) else _dtype_for_scalar(type(x))


def is_float_or_fp_tensor(x):
    return _dtype_for_scalar_or_tensor(x).is_floating_point


def is_complex_or_complex_tensor(x):
    return _dtype_for_scalar_or_tensor(x).is_complex


def _category(dtype):
    return {
        torch.bool: 0,
        torch.SymBool: 0,
        # int
        torch.uint8: 1,
        torch.int8: 1,
        torch.int16: 1,
        torch.int32: 1,
        torch.int64: 1,
        torch.SymInt: 1,
        # float
        torch.float16: 2,
        torch.float32: 2,
        torch.float64: 2,
        torch.SymFloat: 2,
        # complex
        torch.complex64: 3,
        torch.complex128: 3,
    }[dtype]
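
# Categories are ordered bool (0) < int (1) < float (2) < complex (3).
# nep50_to_tensors below compares the weak scalar's category to the tensor's:
# for example, _category(torch.int64) == 1 and _category(torch.float32) == 2,
# so a Python int combined with a float32 tensor adopts the tensor's dtype.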


def nep50_to_tensors(x1, x2, handle_weaks, function_name):
    """If either input is a Python scalar, type-promote it following NEP 50."""

    def to_tensor(scalar, dtype=None):
        if dtype is None:
            dtype = _dtype_for_scalar(type(scalar))
            dtype = get_default_dtype_for(dtype)
        return torch.as_tensor(scalar, dtype=dtype)

    x1_is_weak = not isinstance(x1, torch.Tensor)
    x2_is_weak = not isinstance(x2, torch.Tensor)
    if not handle_weaks or (x1_is_weak and x2_is_weak):
        x1 = to_tensor(x1) if x1_is_weak else x1
        x2 = to_tensor(x2) if x2_is_weak else x2
        return x1, x2

    # scalar <op> tensor: NEP 50
    assert x1_is_weak != x2_is_weak

    weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1)

    # find the dtype for the weak operand's Python type
    weak_dtype = _dtype_for_scalar(type(weak))

    cat_weak = _category(weak_dtype)
    cat_not_weak = _category(not_weak.dtype)

    dt = not_weak.dtype if cat_weak <= cat_not_weak else None

    # special-case complex + float32
    if weak_dtype.is_complex and not_weak.dtype == torch.float32:
        dt = torch.complex64

    # detect overflows: in PyTorch, uint8(-1) wraps around to 255,
    # while NEP 50 mandates an exception.
    #
    # Note that we only check whether each operand of the binop overflows,
    # not the result. Consider, e.g., `uint8(100) + 200`: both operands fit
    # in uint8, but the result overflows and wraps around.
    # NumPy emits a RuntimeWarning, PyTorch does not, and we do not either.
    if cat_weak == 1 and cat_not_weak == 1:
        # integers
        iinfo = torch.iinfo(not_weak.dtype)
        if not (iinfo.min <= weak <= iinfo.max):
            raise OverflowError(
                f"Python integer {weak} out of bounds for {not_weak.dtype}"
            )
    if weak_dtype != dt or function_name in _NEP50_FUNCS_TENSOR_ONLY:
        # finally, turn `weak` into a 0D tensor when its dtype needs to change
        # or when the function only accepts tensor arguments.
        weak = to_tensor(weak, dt)

    return (weak, not_weak) if x1_is_weak else (not_weak, weak)
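
# A sketch of the behaviour for the scalar <op> tensor case:
#
#     >>> t = torch.ones(3, dtype=torch.uint8)
#     >>> _, w = nep50_to_tensors(t, 2, handle_weaks=True, function_name="add")
#     >>> w.dtype     # the weak Python int adopts the tensor's dtype
#     torch.uint8
#
# while nep50_to_tensors(t, 300, True, "add") raises an OverflowError because
# 300 does not fit into uint8.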