#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/cuda/jit_utils.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
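// Consequently, the kernel entry points below are ordinary namespace-scope
// functions (external linkage) rather than static functions or members of an
// anonymous namespace.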

namespace at::native {

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char gcd_name[] = "gcd";
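// On the jiterator path the kernel body is compiled at runtime (via NVRTC)
// from a CUDA source string (here gcd_string, expected to come from the
// Math.cuh include above), so the op stays out of the statically compiled
// binary until its first use.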
void gcd_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
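    // AT_DISPATCH_INTEGRAL_TYPES instantiates the lambda once per integral
    // dtype (uint8, int8, int16, int32, int64), binding scalar_t to each.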
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
      jitted_gpu_kernel</*name=*/gcd_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 2>(iter, gcd_string);
    });
  #else
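    // Eager fallback for builds without the jiterator: one precompiled kernel
    // per integral dtype, reusing the calc_gcd helper that the CPU binary op
    // kernels also use.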
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
        return calc_gcd(a, b);
      });
    });
  #endif // AT_USE_JITERATOR()
}

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char lcm_name[] = "lcm";
void lcm_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cuda", [&]() {
      jitted_gpu_kernel</*name=*/lcm_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 2>(iter, lcm_string);
    });
  #else
    AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a, scalar_t b) -> scalar_t {
        scalar_t g = calc_gcd(a, b);
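        // g divides a exactly, so evaluating a / g * b (divide first) avoids
        // overflowing the intermediate a * b whenever the lcm itself fits in
        // scalar_t. The g == 0 check covers lcm(0, 0) = 0.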
        return (g == 0) ? 0 : ::abs(a / g * b);
      });
    });
  #endif // AT_USE_JITERATOR()
}

REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
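// These registrations make torch.gcd / torch.lcm on CUDA tensors dispatch to
// the kernels above; gcd_stub and lcm_stub are the DispatchStub declarations
// pulled in through the BinaryOps.h include.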

} // namespace at::native