1 #pragma once 2 #include <c10/macros/Macros.h> 3 #include <c10/util/BFloat16.h> 4 #include <c10/util/Float8_e4m3fn.h> 5 #include <c10/util/Float8_e4m3fnuz.h> 6 #include <c10/util/Float8_e5m2.h> 7 #include <c10/util/Float8_e5m2fnuz.h> 8 #include <c10/util/Half.h> 9 #include <c10/util/complex.h> 10 11 #include <type_traits> 12 13 C10_CLANG_DIAGNOSTIC_PUSH() 14 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") 15 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") 16 #endif 17 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion") 18 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion") 19 #endif 20 21 namespace c10 { 22 23 template <typename dest_t, typename src_t> 24 struct needs_real { 25 constexpr static bool value = 26 (is_complex<src_t>::value && !is_complex<dest_t>::value); 27 }; 28 29 template <bool, typename src_t> 30 struct maybe_real { applymaybe_real31 C10_HOST_DEVICE static inline src_t apply(src_t src) { 32 return src; 33 } 34 }; 35 36 template <typename src_t> 37 struct maybe_real<true, src_t> { 38 C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { 39 return src.real(); 40 } 41 }; 42 43 template <bool, typename src_t> 44 struct maybe_bool { 45 C10_HOST_DEVICE static inline src_t apply(src_t src) { 46 return src; 47 } 48 }; 49 50 template <typename src_t> 51 struct maybe_bool<true, src_t> { 52 C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) { 53 // Don't use bool operator so as to to also compile for ComplexHalf. 54 return src.real() || src.imag(); 55 } 56 }; 57 58 // Note: deliberately ignores undefined behavior, consistent with NumPy. 59 // PyTorch's type conversions can cause a variety of undefined behavior, 60 // including float to integral overflow and signed to unsigned integer overflow. 61 // Some of this undefined behavior is addressed below. 62 template <typename dest_t, typename src_t> 63 struct static_cast_with_inter_type { 64 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply( 65 src_t src) { 66 constexpr bool real = needs_real<dest_t, src_t>::value; 67 auto r = maybe_real<real, src_t>::apply(src); 68 return static_cast<dest_t>(r); 69 } 70 }; 71 72 // Partial template specialization for casting to bool. 73 // Need to handle complex types separately, as we don't 74 // simply want to cast the real part to bool. 75 template <typename src_t> 76 struct static_cast_with_inter_type<bool, src_t> { 77 C10_HOST_DEVICE static inline bool apply(src_t src) { 78 constexpr bool complex = needs_real<bool, src_t>::value; 79 return static_cast<bool>(maybe_bool<complex, src_t>::apply(src)); 80 } 81 }; 82 83 // Partial template instantiation for casting to uint8. 84 // Note: Converting from negative float values to unsigned integer types is 85 // undefined behavior in C++, and current CPU and GPU compilers exhibit 86 // divergent behavior. Casting from negative float values to signed 87 // integer types and then to unsigned integer types is not undefined, 88 // however, so this cast improves the consistency of type conversions 89 // to uint8 across compilers. 90 // Further note: Type conversions across compilers still have other undefined 91 // and divergent behavior. 92 template <typename src_t> 93 struct static_cast_with_inter_type<uint8_t, src_t> { 94 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply( 95 src_t src) { 96 constexpr bool real = needs_real<uint8_t, src_t>::value; 97 return static_cast<uint8_t>( 98 static_cast<int64_t>(maybe_real<real, src_t>::apply(src))); 99 } 100 }; 101 102 template <> 103 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::BFloat16> { 104 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 105 c10::Half> 106 apply(c10::BFloat16 src) { 107 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src}); 108 } 109 }; 110 111 template <> 112 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Float8_e5m2> { 113 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 114 c10::Half> 115 apply(c10::Float8_e5m2 src) { 116 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src}); 117 } 118 }; 119 120 template <> 121 struct static_cast_with_inter_type< 122 c10::complex<c10::Half>, 123 c10::Float8_e5m2fnuz> { 124 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 125 c10::Half> 126 apply(c10::Float8_e5m2fnuz src) { 127 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src}); 128 } 129 }; 130 131 template <> 132 struct static_cast_with_inter_type< 133 c10::complex<c10::Half>, 134 c10::Float8_e4m3fn> { 135 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 136 c10::Half> 137 apply(c10::Float8_e4m3fn src) { 138 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src}); 139 } 140 }; 141 142 template <> 143 struct static_cast_with_inter_type< 144 c10::complex<c10::Half>, 145 c10::Float8_e4m3fnuz> { 146 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 147 c10::Half> 148 apply(c10::Float8_e4m3fnuz src) { 149 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src}); 150 } 151 }; 152 153 template <> 154 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> { 155 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 156 c10::Half> 157 apply(c10::Half src) { 158 return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src}); 159 } 160 }; 161 162 template <> 163 struct static_cast_with_inter_type< 164 c10::complex<c10::Half>, 165 c10::complex<double>> { 166 C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex< 167 c10::Half> 168 apply(c10::complex<double> src) { 169 return static_cast<c10::complex<c10::Half>>( 170 static_cast<c10::complex<float>>(src)); 171 } 172 }; 173 174 template <typename To, typename From> 175 C10_HOST_DEVICE To convert(From f) { 176 return static_cast_with_inter_type<To, From>::apply(f); 177 } 178 179 // Define separately to avoid being inlined and prevent code-size bloat 180 [[noreturn]] C10_API void report_overflow(const char* name); 181 182 template <typename To, typename From> 183 To checked_convert(From f, const char* name) { 184 // Converting to bool can't overflow so we exclude this case from checking. 185 if (!std::is_same_v<To, bool> && overflows<To, From>(f)) { 186 report_overflow(name); 187 } 188 return convert<To, From>(f); 189 } 190 191 } // namespace c10 192 193 C10_CLANG_DIAGNOSTIC_POP() 194 195 // Trigger tests for D25440771. TODO: Remove this line any time you want. 196