#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {
namespace {

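// complex(real, imag): builds real + imag*i elementwise; dispatched over
// float, double, and half inputs.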
void complex_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND(kHalf, iter.input_dtype(0), "complex_cuda", [&]() {
    gpu_kernel(
        iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> c10::complex<scalar_t> {
          return c10::complex<scalar_t>(a, b);
        });
  });
}

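// polar(abs, angle): builds abs * (cos(angle) + i*sin(angle)) elementwise;
// dispatched over float and double inputs.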
void polar_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(0), "polar_cuda", [&]() {
    gpu_kernel(
        iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> c10::complex<scalar_t> {
          return c10::complex<scalar_t>(a * std::cos(b), a * std::sin(b));
        });
  });
}

} // anonymous namespace

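// Register the CUDA kernels with the complex_stub / polar_stub dispatch stubs
// (declared in ATen/native/TensorFactories.h), which back at::complex and at::polar.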
REGISTER_DISPATCH(complex_stub, &complex_kernel_cuda);
REGISTER_DISPATCH(polar_stub, &polar_kernel_cuda);

} // namespace at::native