# mypy: ignore-errors

"""Dtypes/scalar type implementations with torch dtypes.

Here `dtype` is always a torch.dtype, this module knows nothing about
scalar types, wrapper dtypes or anything like that. PyTorch only.
"""
from collections import namedtuple

import torch


# defaults: mimic NumPy, allow user control
DefaultDTypes = namedtuple(
    "DefaultDTypes", ["float_dtype", "complex_dtype", "int_dtype"]
)

# a global state
# We set it the first time we call default_dtypes() to avoid importing
# torch._dynamo.config at module import time and creating a circular import.
_default_dtypes = None


def default_dtypes():
    """Return the default float/complex/int torch dtypes as a ``DefaultDTypes``.

    On the first call (while ``_default_dtypes`` is None) the dtypes are looked
    up on ``torch`` by the names stored in ``torch._dynamo.config``
    (``numpy_default_float``, ``numpy_default_complex``, ``numpy_default_int``)
    and cached in ``_default_dtypes``; subsequent calls return the cached value
    without re-reading the config.
    """
    global _default_dtypes
    if _default_dtypes is None:
        # Local import: see the note on `_default_dtypes` above.
        import torch._dynamo.config as config

        _default_dtypes = DefaultDTypes(
            float_dtype=getattr(torch, config.numpy_default_float),
            complex_dtype=getattr(torch, config.numpy_default_complex),
            int_dtype=getattr(torch, config.numpy_default_int),
        )
        # Sanity-check that the configured names resolved to real dtypes.
        assert isinstance(_default_dtypes.float_dtype, torch.dtype)
        assert isinstance(_default_dtypes.complex_dtype, torch.dtype)
        assert isinstance(_default_dtypes.int_dtype, torch.dtype)
    return _default_dtypes


def get_default_dtype_for(dtype):
    """Default scalar type given sctype category.

    ``torch.bool`` maps to itself; complex dtypes map to the default complex
    dtype, floating-point dtypes to the default float dtype, and everything
    else (assumed to be an integer dtype) to the default int dtype.
    """
    if dtype == torch.bool:
        return dtype
    if dtype.is_complex:
        return default_dtypes().complex_dtype
    if dtype.is_floating_point:
        return default_dtypes().float_dtype
    # else, it must be (some) integer
    return default_dtypes().int_dtype


from . import _casting_dicts as _cd


def can_cast_impl(from_torch_dtype, to_torch_dtype, casting):
    """Return whether ``from_torch_dtype`` may be cast to ``to_torch_dtype``.

    The answer is looked up in the precomputed table
    ``_cd._can_cast_dict[casting][from][to]``; unknown keys raise KeyError.
    """
    return _cd._can_cast_dict[casting][from_torch_dtype][to_torch_dtype]


def result_type_impl(*tensors):
    """Return the promoted torch dtype of all arguments.

    Only the ``.dtype`` attribute of each argument is read. Promotion is
    applied pairwise, left to right, through ``_cd._result_type_dict``.
    At least one argument is required.
    """
    # NB: torch dtypes here
    dtyp = tensors[0].dtype
    # With a single argument the loop body never runs and its dtype is returned.
    for curr in tensors[1:]:
        dtyp = _cd._result_type_dict[dtyp][curr.dtype]
    return dtyp


def python_type_for_torch(dtyp):
    """Get the Python scalar type corresponding to a torch dtype.

    Floating-point dtypes map to ``float``, complex dtypes to ``complex``,
    ``torch.bool`` to ``bool`` and every other dtype to ``int``.
    """
    if dtyp.is_floating_point:
        return float
    if dtyp.is_complex:
        return complex
    if dtyp == torch.bool:
        return bool
    return int


# ### NEP 50 helpers ###

_SCALAR_TYPES = (int, bool, float, complex)

_SCALAR_AND_SYMBOLIC_TYPES = (
    *_SCALAR_TYPES,
    torch.SymInt,
    torch.SymFloat,
    torch.SymBool,
)

# Functions whose implementations require both operands to be tensors, so a
# Python scalar operand is always converted to a tensor for them.
_NEP50_FUNCS_TENSOR_ONLY = (
    "minimum",
    "maximum",
    "logaddexp",
    "logaddexp2",
    "lcm",
    "gcd",
    "hypot",
    "heaviside",
    "fmod",
    "fmin",
    "fmax",
    "copysign",
    "arctan2",
)

# Torch dtype used for a Python scalar type (or its symbolic counterpart).
# Built once at import time instead of on every `_dtype_for_scalar` call.
_SCALAR_TYPE_TO_DTYPE = {
    bool: torch.bool,
    torch.SymBool: torch.bool,
    int: torch.int64,
    torch.SymInt: torch.int64,
    float: torch.float64,
    torch.SymFloat: torch.float64,
    complex: torch.complex128,
}

# Promotion "kind" of a dtype, ordered bool < integer < float < complex.
# Built once at import time instead of on every `_category` call.
_DTYPE_CATEGORY = {
    torch.bool: 0,
    torch.SymBool: 0,
    # int
    torch.uint8: 1,
    torch.int8: 1,
    torch.int16: 1,
    torch.int32: 1,
    torch.int64: 1,
    torch.SymInt: 1,
    # float
    torch.float16: 2,
    torch.float32: 2,
    torch.float64: 2,
    torch.SymFloat: 2,
    # complex
    torch.complex64: 3,
    torch.complex128: 3,
}


def is_scalar(x):
    """Return True if ``x`` is a Python int, bool, float or complex."""
    return isinstance(x, _SCALAR_TYPES)


def is_scalar_or_symbolic(x):
    """Return True if ``x`` is a Python scalar or a torch SymInt/SymFloat/SymBool."""
    return isinstance(x, _SCALAR_AND_SYMBOLIC_TYPES)


def _dtype_for_scalar(py_type):
    """Map a Python scalar type (or Sym* type) to its torch dtype.

    The lookup is on the exact type; unsupported types raise KeyError.
    """
    return _SCALAR_TYPE_TO_DTYPE[py_type]


def _dtype_for_scalar_or_tensor(x):
    """Return ``x.dtype`` for a tensor, else the dtype for the scalar's type."""
    return x.dtype if isinstance(x, torch.Tensor) else _dtype_for_scalar(type(x))


def is_float_or_fp_tensor(x):
    """Return True if ``x`` is a Python float/SymFloat or a floating-point tensor."""
    return _dtype_for_scalar_or_tensor(x).is_floating_point


def is_complex_or_complex_tensor(x):
    """Return True if ``x`` is a Python complex or a complex tensor."""
    return _dtype_for_scalar_or_tensor(x).is_complex


def _category(dtype):
    """Return the promotion category of ``dtype``.

    The categories are 0 = bool, 1 = integer, 2 = float and 3 = complex.
    Unsupported dtypes raise KeyError.
    """
    return _DTYPE_CATEGORY[dtype]


def nep50_to_tensors(x1, x2, handle_weaks, function_name):
    """If either of inputs is a python scalar, type-promote with NEP 50.

    Args:
        x1, x2: the two operands of a binary operation. Each is either a
            ``torch.Tensor`` or a Python scalar; presumably a scalar accepted
            by ``is_scalar_or_symbolic`` (verify against callers).
        handle_weaks: if false, or if both operands are scalars, every scalar
            operand is converted to a tensor of its category's default dtype
            and no NEP 50 promotion is performed.
        function_name: name of the calling function. For names in
            ``_NEP50_FUNCS_TENSOR_ONLY`` the scalar operand is always
            converted to a tensor.

    Returns:
        The pair ``(x1, x2)`` in the original order. The scalar ("weak")
        operand is converted to a 0-D tensor when the promoted dtype differs
        from the scalar's own dtype, or when ``function_name`` requires
        tensors. Otherwise it is returned unchanged as a Python scalar.

    Raises:
        OverflowError: if a Python integer does not fit in the integer dtype
            of the tensor operand.
    """

    def to_tensor(scalar, dtype=None):
        # Without an explicit dtype, use the configured default dtype for the
        # scalar's category (e.g. the default float dtype for a Python float).
        if dtype is None:
            dtype = _dtype_for_scalar(type(scalar))
            dtype = get_default_dtype_for(dtype)
        return torch.as_tensor(scalar, dtype=dtype)

    x1_is_weak = not isinstance(x1, torch.Tensor)
    x2_is_weak = not isinstance(x2, torch.Tensor)
    if not handle_weaks or (x1_is_weak and x2_is_weak):
        x1 = to_tensor(x1) if x1_is_weak else x1
        x2 = to_tensor(x2) if x2_is_weak else x2
        return x1, x2

    # scalar <op> tensor: NEP 50
    assert x1_is_weak != x2_is_weak

    weak, not_weak = (x1, x2) if x1_is_weak else (x2, x1)

    # find the dtype for the weak's type
    weak_dtype = _dtype_for_scalar(type(weak))

    cat_weak = _category(weak_dtype)
    cat_not_weak = _category(not_weak.dtype)

    # If the scalar's category does not exceed the tensor's, the tensor's
    # dtype wins (e.g. a Python int with a float32 tensor -> float32).
    # Otherwise dt=None, and the scalar later gets the default dtype of its
    # own category.
    dt = not_weak.dtype if cat_weak <= cat_not_weak else None

    # special-case complex + float32
    if weak_dtype.is_complex and not_weak.dtype == torch.float32:
        dt = torch.complex64

    # detect overflows: in PyTorch, uint8(-1) wraps around to 255,
    # while NEP50 mandates an exception.
    #
    # Note that we only check if each element of the binop overflows,
    # not the result. Consider, e.g. `uint8(100) + 200`. Operands are OK
    # in uint8, but the result overflows and wraps around 255.
    # Numpy emits a RuntimeWarning, PyTorch does not, and we do not either.
    if cat_weak == 1 and cat_not_weak == 1:
        # integers
        iinfo = torch.iinfo(not_weak.dtype)
        if not (iinfo.min <= weak <= iinfo.max):
            raise OverflowError(
                f"Python integer {weak} out of bounds for {not_weak.dtype}"
            )
    if weak_dtype != dt or function_name in _NEP50_FUNCS_TENSOR_ONLY:
        # finally, make `weak` into a 0D tensor: either its dtype must change,
        # or the function requires both parameters to be tensors.
        weak = to_tensor(weak, dt)

    return (weak, not_weak) if x1_is_weak else (not_weak, weak)