#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {
namespace {

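// Kernel behind the complex_stub registration below: for each element pair,
// combines a real input `a` and an imaginary input `b` into a
// c10::complex<scalar_t>(a, b) output.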
void complex_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.input_dtype(0), "complex_cuda", [&]() {
    gpu_kernel(
      iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> c10::complex<scalar_t> {
        return c10::complex<scalar_t>(a, b);
      });
  });
}

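// Kernel behind the polar_stub registration below: interprets the inputs as
// (magnitude a, angle b) and produces the complex value
// a * cos(b) + i * a * sin(b) element-wise.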
void polar_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(0), "polar_cuda", [&]() {
    gpu_kernel(
      iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> c10::complex<scalar_t> {
        return c10::complex<scalar_t>(a * std::cos(b), a * std::sin(b));
      });
  });
}

} // anonymous namespace

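// Hook the kernels above into the shared complex_stub / polar_stub dispatch
// points so CUDA tensors use these implementations (these stubs presumably
// back at::complex and at::polar).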
REGISTER_DISPATCH(complex_stub, &complex_kernel_cuda);
REGISTER_DISPATCH(polar_stub, &polar_kernel_cuda);

} // namespace at::native