xref: /aosp_15_r20/external/pytorch/c10/util/complex_math.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/complex.h>
2 
3 #include <cmath>
4 
// Note [ Complex Square root in libc++]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In libc++, complex square root is computed using the polar form.
// This is a reasonably fast algorithm, but it can result in significant
// numerical errors when arg is close to 0, pi/2, pi, or 3pi/4.
// In that case, provide a more conservative implementation which is
// slower but less prone to those kinds of errors.
// In libstdc++, complex square root yields invalid results
// for -x-0.0j unless the C99 csqrt/csqrtf fallbacks are used.
14 
15 #if defined(_LIBCPP_VERSION) || \
16     (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))
17 
18 namespace {
19 template <typename T>
compute_csqrt(const c10::complex<T> & z)20 c10::complex<T> compute_csqrt(const c10::complex<T>& z) {
21   constexpr auto half = T(.5);
22 
23   // Trust standard library to correctly handle infs and NaNs
24   if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) ||
25       std::isnan(z.imag())) {
26     return static_cast<c10::complex<T>>(
27         std::sqrt(static_cast<std::complex<T>>(z)));
28   }
29 
30   // Special case for square root of pure imaginary values
31   if (z.real() == T(0)) {
32     if (z.imag() == T(0)) {
33       return c10::complex<T>(T(0), z.imag());
34     }
35     auto v = std::sqrt(half * std::abs(z.imag()));
36     return c10::complex<T>(v, std::copysign(v, z.imag()));
37   }
38 
39   // At this point, z is non-zero and finite
40   if (z.real() >= 0.0) {
41     auto t = std::sqrt((z.real() + std::abs(z)) * half);
42     return c10::complex<T>(t, half * (z.imag() / t));
43   }
44 
45   auto t = std::sqrt((-z.real() + std::abs(z)) * half);
46   return c10::complex<T>(
47       half * std::abs(z.imag() / t), std::copysign(t, z.imag()));
48 }
49 
50 // Compute complex arccosine using formula from W. Kahan
51 // "Branch Cuts for Complex Elementary Functions" 1986 paper:
52 // cacos(z).re = 2*atan2(sqrt(1-z).re(), sqrt(1+z).re())
53 // cacos(z).im = asinh((sqrt(conj(1+z))*sqrt(1-z)).im())
54 template <typename T>
compute_cacos(const c10::complex<T> & z)55 c10::complex<T> compute_cacos(const c10::complex<T>& z) {
56   auto constexpr one = T(1);
57   // Trust standard library to correctly handle infs and NaNs
58   if (std::isinf(z.real()) || std::isinf(z.imag()) || std::isnan(z.real()) ||
59       std::isnan(z.imag())) {
60     return static_cast<c10::complex<T>>(
61         std::acos(static_cast<std::complex<T>>(z)));
62   }
63   auto a = compute_csqrt(c10::complex<T>(one - z.real(), -z.imag()));
64   auto b = compute_csqrt(c10::complex<T>(one + z.real(), z.imag()));
65   auto c = compute_csqrt(c10::complex<T>(one + z.real(), -z.imag()));
66   auto r = T(2) * std::atan2(a.real(), b.real());
67   // Explicitly unroll (a*c).imag()
68   auto i = std::asinh(a.real() * c.imag() + a.imag() * c.real());
69   return c10::complex<T>(r, i);
70 }
71 } // anonymous namespace
72 
73 namespace c10_complex_math {
74 namespace _detail {
sqrt(const c10::complex<float> & in)75 c10::complex<float> sqrt(const c10::complex<float>& in) {
76   return compute_csqrt(in);
77 }
78 
sqrt(const c10::complex<double> & in)79 c10::complex<double> sqrt(const c10::complex<double>& in) {
80   return compute_csqrt(in);
81 }
82 
acos(const c10::complex<float> & in)83 c10::complex<float> acos(const c10::complex<float>& in) {
84   return compute_cacos(in);
85 }
86 
acos(const c10::complex<double> & in)87 c10::complex<double> acos(const c10::complex<double>& in) {
88   return compute_cacos(in);
89 }
90 
91 } // namespace _detail
92 } // namespace c10_complex_math
93 #endif
94