xref: /aosp_15_r20/external/pytorch/c10/util/complex_utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H)
2 #error \
3     "c10/util/complex_utils.h is not meant to be individually included. Include c10/util/complex.h instead."
4 #endif
5 
6 #include <limits>
7 
8 namespace c10 {
9 
10 template <typename T>
11 struct is_complex : public std::false_type {};
12 
13 template <typename T>
14 struct is_complex<std::complex<T>> : public std::true_type {};
15 
16 template <typename T>
17 struct is_complex<c10::complex<T>> : public std::true_type {};
18 
19 // Extract double from std::complex<double>; is identity otherwise
20 // TODO: Write in more idiomatic C++17
21 template <typename T>
22 struct scalar_value_type {
23   using type = T;
24 };
25 template <typename T>
26 struct scalar_value_type<std::complex<T>> {
27   using type = T;
28 };
29 template <typename T>
30 struct scalar_value_type<c10::complex<T>> {
31   using type = T;
32 };
33 
34 } // namespace c10
35 
36 namespace std {
37 
38 template <typename T>
39 class numeric_limits<c10::complex<T>> : public numeric_limits<T> {};
40 
41 template <typename T>
42 bool isnan(const c10::complex<T>& v) {
43   return std::isnan(v.real()) || std::isnan(v.imag());
44 }
45 
46 } // namespace std
47