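// CUDA kernels for elementwise unary operations on complex tensors:
// angle (the phase of each element) and conj_physical (the materialized
// complex conjugate).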
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Copy.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

namespace at::native {

// We manually overload angle because std::arg does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t angle_wrapper(scalar_t v) {
  if (at::_isnan(v)) {
    return v;
  }
  return v < 0 ? M_PI : 0;
}

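// Complex overload: the phase is std::arg(v), returned as a complex value
// with a zero imaginary part so the output dtype matches the input.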
template<typename T>
__host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T> v) {
  return c10::complex<T>{std::arg(v), 0};
}

#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char angle_name[] = "angle_kernel";
#endif

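// For complex dtypes (including ComplexHalf), prefer the jiterator path,
// which JIT-compiles angle_string at runtime; otherwise fall back to a
// precompiled gpu_kernel. Real dtypes always use angle_wrapper directly.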
void angle_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto angle_string = jiterator_stringify(
        template <typename T>
        T angle_kernel(T v) {
          return T{std::arg(v)};
        }
    ); // angle string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "angle_cuda", [&]() {
        jitted_gpu_kernel<
          /*name=*/ angle_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 1>(iter, angle_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "angle_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return angle_wrapper(a);
        });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES(dtype, "angle_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return angle_wrapper(a);
        });
    });
  }
}

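// conj_physical materializes the complex conjugate of each element into the
// output rather than flipping the tensor's conjugate bit.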
// NB: Ignores the negative bit on tensors
CONSTEXPR_EXCEPT_WIN_CUDA char conj_name[] = "conj_kernel";
void conj_kernel_cuda(TensorIteratorBase& iter) {
  // ComplexHalf is not covered by AT_DISPATCH_CASE_COMPLEX_TYPES, so it gets
  // its own path, using the jiterator when available.
  auto conj_chalf = [&] {
    using scalar_t = c10::complex<at::Half>;
    #if AT_USE_JITERATOR()
      static const auto conj_string = jiterator_stringify(
        template <typename T>
        T conj_kernel(T z) {
          return std::conj(z);
        }
      );
      jitted_gpu_kernel<conj_name, scalar_t, scalar_t, 1>(iter, conj_string);
    #else
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
          return std::conj(a);
      });
    #endif
  };

  AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_cuda",
    AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
      // Conj is a no-op for non-complex types
      direct_copy_kernel_cuda(iter);
    })
    AT_DISPATCH_CASE_COMPLEX_TYPES([&] {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
        return std::conj(a);
      });
    })
    AT_DISPATCH_CASE(kComplexHalf, conj_chalf)
  );
}

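// Wire the CUDA kernels into the corresponding dispatch stubs so the angle
// and conj_physical operators use them on CUDA tensors.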
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);

} // namespace at::native