xref: /aosp_15_r20/external/pytorch/c10/util/TypeCast.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/macros/Macros.h>
3 #include <c10/util/BFloat16.h>
4 #include <c10/util/Float8_e4m3fn.h>
5 #include <c10/util/Float8_e4m3fnuz.h>
6 #include <c10/util/Float8_e5m2.h>
7 #include <c10/util/Float8_e5m2fnuz.h>
8 #include <c10/util/Half.h>
9 #include <c10/util/complex.h>
10 
11 #include <type_traits>
12 
13 C10_CLANG_DIAGNOSTIC_PUSH()
14 #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
15 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
16 #endif
17 #if C10_CLANG_HAS_WARNING("-Wimplicit-int-float-conversion")
18 C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-int-float-conversion")
19 #endif
20 
21 namespace c10 {
22 
23 template <typename dest_t, typename src_t>
24 struct needs_real {
25   constexpr static bool value =
26       (is_complex<src_t>::value && !is_complex<dest_t>::value);
27 };
28 
29 template <bool, typename src_t>
30 struct maybe_real {
applymaybe_real31   C10_HOST_DEVICE static inline src_t apply(src_t src) {
32     return src;
33   }
34 };
35 
36 template <typename src_t>
37 struct maybe_real<true, src_t> {
38   C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
39     return src.real();
40   }
41 };
42 
43 template <bool, typename src_t>
44 struct maybe_bool {
45   C10_HOST_DEVICE static inline src_t apply(src_t src) {
46     return src;
47   }
48 };
49 
50 template <typename src_t>
51 struct maybe_bool<true, src_t> {
52   C10_HOST_DEVICE static inline decltype(auto) apply(src_t src) {
53     // Don't use bool operator so as to to also compile for ComplexHalf.
54     return src.real() || src.imag();
55   }
56 };
57 
58 // Note: deliberately ignores undefined behavior, consistent with NumPy.
59 // PyTorch's type conversions can cause a variety of undefined behavior,
60 // including float to integral overflow and signed to unsigned integer overflow.
61 // Some of this undefined behavior is addressed below.
62 template <typename dest_t, typename src_t>
63 struct static_cast_with_inter_type {
64   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline dest_t apply(
65       src_t src) {
66     constexpr bool real = needs_real<dest_t, src_t>::value;
67     auto r = maybe_real<real, src_t>::apply(src);
68     return static_cast<dest_t>(r);
69   }
70 };
71 
72 // Partial template specialization for casting to bool.
73 // Need to handle complex types separately, as we don't
74 // simply want to cast the real part to bool.
75 template <typename src_t>
76 struct static_cast_with_inter_type<bool, src_t> {
77   C10_HOST_DEVICE static inline bool apply(src_t src) {
78     constexpr bool complex = needs_real<bool, src_t>::value;
79     return static_cast<bool>(maybe_bool<complex, src_t>::apply(src));
80   }
81 };
82 
83 // Partial template instantiation for casting to uint8.
84 // Note: Converting from negative float values to unsigned integer types is
85 // undefined behavior in C++, and current CPU and GPU compilers exhibit
86 // divergent behavior. Casting from negative float values to signed
87 // integer types and then to unsigned integer types is not undefined,
88 // however, so this cast improves the consistency of type conversions
89 // to uint8 across compilers.
90 // Further note: Type conversions across compilers still have other undefined
91 // and divergent behavior.
92 template <typename src_t>
93 struct static_cast_with_inter_type<uint8_t, src_t> {
94   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline uint8_t apply(
95       src_t src) {
96     constexpr bool real = needs_real<uint8_t, src_t>::value;
97     return static_cast<uint8_t>(
98         static_cast<int64_t>(maybe_real<real, src_t>::apply(src)));
99   }
100 };
101 
102 template <>
103 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::BFloat16> {
104   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
105       c10::Half>
106   apply(c10::BFloat16 src) {
107     return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
108   }
109 };
110 
111 template <>
112 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Float8_e5m2> {
113   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
114       c10::Half>
115   apply(c10::Float8_e5m2 src) {
116     return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
117   }
118 };
119 
120 template <>
121 struct static_cast_with_inter_type<
122     c10::complex<c10::Half>,
123     c10::Float8_e5m2fnuz> {
124   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
125       c10::Half>
126   apply(c10::Float8_e5m2fnuz src) {
127     return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
128   }
129 };
130 
131 template <>
132 struct static_cast_with_inter_type<
133     c10::complex<c10::Half>,
134     c10::Float8_e4m3fn> {
135   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
136       c10::Half>
137   apply(c10::Float8_e4m3fn src) {
138     return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
139   }
140 };
141 
142 template <>
143 struct static_cast_with_inter_type<
144     c10::complex<c10::Half>,
145     c10::Float8_e4m3fnuz> {
146   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
147       c10::Half>
148   apply(c10::Float8_e4m3fnuz src) {
149     return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
150   }
151 };
152 
153 template <>
154 struct static_cast_with_inter_type<c10::complex<c10::Half>, c10::Half> {
155   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
156       c10::Half>
157   apply(c10::Half src) {
158     return static_cast<c10::complex<c10::Half>>(c10::complex<float>{src});
159   }
160 };
161 
162 template <>
163 struct static_cast_with_inter_type<
164     c10::complex<c10::Half>,
165     c10::complex<double>> {
166   C10_HOST_DEVICE __ubsan_ignore_undefined__ static inline c10::complex<
167       c10::Half>
168   apply(c10::complex<double> src) {
169     return static_cast<c10::complex<c10::Half>>(
170         static_cast<c10::complex<float>>(src));
171   }
172 };
173 
174 template <typename To, typename From>
175 C10_HOST_DEVICE To convert(From f) {
176   return static_cast_with_inter_type<To, From>::apply(f);
177 }
178 
// Error path used by checked_convert when a value does not fit the
// destination type; `name` is forwarded unchanged from the caller.
// Marked [[noreturn]]: control never comes back to the call site.
// NOTE(review): how the failure is surfaced (presumably an exception) is not
// visible in this header -- confirm against the definition.
// Define separately to avoid being inlined and prevent code-size bloat
[[noreturn]] C10_API void report_overflow(const char* name);
181 
182 template <typename To, typename From>
183 To checked_convert(From f, const char* name) {
184   // Converting to bool can't overflow so we exclude this case from checking.
185   if (!std::is_same_v<To, bool> && overflows<To, From>(f)) {
186     report_overflow(name);
187   }
188   return convert<To, From>(f);
189 }
190 
191 } // namespace c10
192 
193 C10_CLANG_DIAGNOSTIC_POP()
194 
195 // Trigger tests for D25440771. TODO: Remove this line any time you want.
196