1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <limits>
3 #include <ATen/native/UnaryOps.h>
4 #include <ATen/native/cuda/Copy.h>
5 #include <ATen/native/cuda/Loops.cuh>
6 #include <ATen/native/cuda/JitLoops.cuh>
7 #include <ATen/Dispatch.h>
8 #include <ATen/NumericUtils.h>
9 #include <ATen/native/DispatchStub.h>
10 #include <ATen/native/TensorIterator.h>
11
12 namespace at::native {
13
// We manually overload angle because std::arg does not work with types other than c10::complex.
template<typename scalar_t>
__host__ __device__ static inline scalar_t angle_wrapper(scalar_t v) {
  // NaN propagates unchanged; for any other real input the phase angle is
  // pi for negative values and 0 for non-negative ones.
  if (!at::_isnan(v)) {
    return v < 0 ? M_PI : 0;
  }
  return v;
}
22
// Complex overload: the phase comes from std::arg; it is returned as a
// complex value with zero imaginary part so the output dtype matches the input.
template<typename T>
__host__ __device__ static inline c10::complex<T> angle_wrapper(c10::complex<T> v) {
  const T phase = std::arg(v);
  return c10::complex<T>(phase, T{0});
}
27
#if AT_USE_JITERATOR()
// Kernel name for the runtime-compiled (jiterator) complex angle kernel.
// Must match the function name inside the jiterator_stringify string that
// is passed to jitted_gpu_kernel in angle_kernel_cuda.
CONSTEXPR_EXCEPT_WIN_CUDA char angle_name[] = "angle_kernel";
#endif
31
// CUDA implementation behind angle_stub: elementwise argument (phase angle)
// of the tensor described by `iter`.
//   - complex dtypes: std::arg(z), returned with zero imaginary part;
//   - real floating dtypes: pi for negatives, 0 otherwise, NaN propagated
//     (see angle_wrapper above).
void angle_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    // Jiterator path: the kernel body travels as a source string and is
    // compiled at runtime. The function name here must match `angle_name`.
    static const auto angle_string = jiterator_stringify(
        template <typename T>
        T angle_kernel(T v) {
          return T{std::arg(v)};
        }
    ); // angle string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "angle_cuda", [&]() {
      jitted_gpu_kernel<
        /*name=*/ angle_name,
        /*return_dtype=*/ scalar_t,
        /*common_dtype=*/ scalar_t,
        /*arity=*/ 1>(iter, angle_string);
    });
#else
    // Jiterator disabled: fall back to a precompiled elementwise kernel.
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "angle_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return angle_wrapper(a);
      });
    });
#endif
  } else {
    // Real-valued inputs always use the precompiled kernel.
    AT_DISPATCH_FLOATING_TYPES(dtype, "angle_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return angle_wrapper(a);
      });
    });
  }
}
64
// NB: Ignores the negative bit on tensors
// Kernel name for the jiterator path of conj_kernel_cuda; must match the
// function name inside its jiterator_stringify string.
CONSTEXPR_EXCEPT_WIN_CUDA char conj_name[] = "conj_kernel";
// CUDA implementation behind conj_physical_stub: elementwise complex
// conjugation of the tensor described by `iter`.
//   - non-complex dtypes: conjugation is the identity, so this is a copy;
//   - complex float/double: precompiled gpu_kernel over std::conj;
//   - complex half: jiterator (runtime-compiled) when available, otherwise
//     a precompiled kernel.
void conj_kernel_cuda(TensorIteratorBase& iter) {
  auto conj_chalf = [&] {
    using scalar_t = c10::complex<at::Half>;
#if AT_USE_JITERATOR()
    // Kernel source compiled at runtime; the function name must match
    // `conj_name` declared above.
    static const auto conj_string = jiterator_stringify(
        template <typename T>
        T conj_kernel(T z) {
          return std::conj(z);
        }
    );
    jitted_gpu_kernel<conj_name, scalar_t, scalar_t, 1>(iter, conj_string);
#else
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
      return std::conj(a);
    });
#endif
  };

  AT_DISPATCH_SWITCH(iter.common_dtype(), "conj_cuda",
    AT_DISPATCH_CASE_ALL_TYPES_AND3(kBool, kBFloat16, kHalf, [&] {
      // Conj is a no-op for non-complex types
      direct_copy_kernel_cuda(iter);
    })
    AT_DISPATCH_CASE_COMPLEX_TYPES([&] {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
        return std::conj(a);
      });
    })
    AT_DISPATCH_CASE(kComplexHalf, conj_chalf)
  );
}
98
// Register the CUDA kernels with their dispatch stubs so angle / conj_physical
// route to these implementations for CUDA tensors.
REGISTER_DISPATCH(angle_stub, &angle_kernel_cuda);
REGISTER_DISPATCH(conj_physical_stub, &conj_kernel_cuda);
101
102 } // namespace at::native
103