#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) #error \ "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead." #endif namespace c10_complex_math { // Exponential functions template C10_HOST_DEVICE inline c10::complex exp(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::exp(static_cast>(x))); #else return static_cast>( std::exp(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::log(static_cast>(x))); #else return static_cast>( std::log(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log10(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::log10(static_cast>(x))); #else return static_cast>( std::log10(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log2(const c10::complex& x) { const c10::complex log2 = c10::complex(::log(2.0), 0.0); return c10_complex_math::log(x) / log2; } // Power functions // #if defined(_LIBCPP_VERSION) || \ (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) namespace _detail { C10_API c10::complex sqrt(const c10::complex& in); C10_API c10::complex sqrt(const c10::complex& in); C10_API c10::complex acos(const c10::complex& in); C10_API c10::complex acos(const c10::complex& in); } // namespace _detail #endif template C10_HOST_DEVICE inline c10::complex sqrt(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::sqrt(static_cast>(x))); #elif !( \ defined(_LIBCPP_VERSION) || \ (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) return static_cast>( std::sqrt(static_cast>(x))); #else return _detail::sqrt(x); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( static_cast>(x), static_cast>(y))); #else return static_cast>(std::pow( static_cast>(x), static_cast>(y))); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const T& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::pow(static_cast>(x), y)); #else return static_cast>( std::pow(static_cast>(x), y)); #endif } template C10_HOST_DEVICE inline c10::complex pow( const T& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::pow(x, static_cast>(y))); #else return static_cast>( std::pow(x, static_cast>(y))); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>(thrust::pow( static_cast>(x), static_cast>(y))); #else return static_cast>(std::pow( static_cast>(x), static_cast>(y))); #endif } template C10_HOST_DEVICE inline c10::complex pow( const c10::complex& x, const U& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::pow(static_cast>(x), y)); #else return static_cast>( std::pow(static_cast>(x), y)); #endif } template C10_HOST_DEVICE inline c10::complex pow( const T& x, const c10::complex& y) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::pow(x, static_cast>(y))); #else return static_cast>( std::pow(x, static_cast>(y))); #endif } // Trigonometric functions template C10_HOST_DEVICE inline c10::complex sin(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::sin(static_cast>(x))); #else return static_cast>( std::sin(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex cos(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::cos(static_cast>(x))); #else return static_cast>( std::cos(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex tan(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::tan(static_cast>(x))); #else return static_cast>( std::tan(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex asin(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::asin(static_cast>(x))); #else return static_cast>( std::asin(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex acos(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::acos(static_cast>(x))); #elif !defined(_LIBCPP_VERSION) return static_cast>( std::acos(static_cast>(x))); #else return _detail::acos(x); #endif } template C10_HOST_DEVICE inline c10::complex atan(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::atan(static_cast>(x))); #else return static_cast>( std::atan(static_cast>(x))); #endif } // Hyperbolic functions template C10_HOST_DEVICE inline c10::complex sinh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::sinh(static_cast>(x))); #else return static_cast>( std::sinh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex cosh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::cosh(static_cast>(x))); #else return static_cast>( std::cosh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex tanh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::tanh(static_cast>(x))); #else return static_cast>( std::tanh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex asinh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::asinh(static_cast>(x))); #else return static_cast>( std::asinh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex acosh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::acosh(static_cast>(x))); #else return static_cast>( std::acosh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex atanh(const c10::complex& x) { #if defined(__CUDACC__) || defined(__HIPCC__) return static_cast>( thrust::atanh(static_cast>(x))); #else return static_cast>( std::atanh(static_cast>(x))); #endif } template C10_HOST_DEVICE inline c10::complex log1p(const c10::complex& z) { #if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \ defined(__HIPCC__) // For Mac, the new implementation yielded a high relative error. Falling back // to the old version for now. // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 // For CUDA we also use this one, as thrust::log(thrust::complex) takes // *forever* to compile // log1p(z) = log(1 + z) // Let's define 1 + z = r * e ^ (i * a), then we have // log(r * e ^ (i * a)) = log(r) + i * a // With z = x + iy, the term r can be written as // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 // So, log(r) is // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) // = 0.5 * log1p(x * (x + 2) + y ^ 2) // we need to use the expression only on certain condition to avoid overflow // and underflow from `(x * (x + 2) + y ^ 2)` T x = z.real(); T y = z.imag(); T zabs = std::abs(z); T theta = std::atan2(y, x + T(1)); if (zabs < 0.5) { T r = x * (T(2) + x) + y * y; if (r == 0) { // handle underflow return {x, theta}; } return {T(0.5) * std::log1p(r), theta}; } else { T z0 = std::hypot(x + 1, y); return {std::log(z0), theta}; } #else // CPU path // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 c10::complex u = z + T(1); if (u == T(1)) { return z; } else { auto log_u = log(u); if (u - T(1) == z) { return log_u; } return log_u * (z / (u - T(1))); } #endif } template C10_HOST_DEVICE inline c10::complex expm1(const c10::complex& z) { // expm1(z) = exp(z) - 1 // Define z = x + i * y // f = e ^ (x + i * y) - 1 // = e ^ x * e ^ (i * y) - 1 // = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y)) // = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y) // = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y) T x = z.real(); T y = z.imag(); T a = std::sin(y / 2); T er = std::expm1(x) * std::cos(y) - T(2) * a * a; T ei = std::exp(x) * std::sin(y); return {er, ei}; } } // namespace c10_complex_math using c10_complex_math::acos; using c10_complex_math::acosh; using c10_complex_math::asin; using c10_complex_math::asinh; using c10_complex_math::atan; using c10_complex_math::atanh; using c10_complex_math::cos; using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::expm1; using c10_complex_math::log; using c10_complex_math::log10; using c10_complex_math::log1p; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; using c10_complex_math::sinh; using c10_complex_math::sqrt; using c10_complex_math::tan; using c10_complex_math::tanh; namespace std { using c10_complex_math::acos; using c10_complex_math::acosh; using c10_complex_math::asin; using c10_complex_math::asinh; using c10_complex_math::atan; using c10_complex_math::atanh; using c10_complex_math::cos; using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::expm1; using c10_complex_math::log; using c10_complex_math::log10; using c10_complex_math::log1p; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; using c10_complex_math::sinh; using c10_complex_math::sqrt; using c10_complex_math::tan; using c10_complex_math::tanh; } // namespace std