xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/GcdLcmKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/JitLoops.cuh>
5 #include <ATen/native/cuda/Loops.cuh>
6 #include <ATen/native/cuda/Math.cuh>
7 #include <ATen/native/TensorIterator.h>
8 #include <ATen/native/BinaryOps.h>
9 #include <ATen/native/cuda/jit_utils.h>
10 
11 // NOTE: CUDA on Windows requires that the enclosing function
12 // of a __device__ lambda not have internal linkage.
13 
14 namespace at::native {
15 
// Name passed to jitted_gpu_kernel as its `name` template argument.
// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char gcd_name[] = "gcd";

// Elementwise greatest-common-divisor kernel for integral dtypes.
//
// Dispatches on iter.common_dtype(). When AT_USE_JITERATOR() is false, a
// GPU_LAMBDA wrapping calc_gcd is run through gpu_kernel. Otherwise the work
// is handed to jitted_gpu_kernel together with gcd_string (defined outside
// this file), using scalar_t as both the return and the common dtype.
void gcd_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if !AT_USE_JITERATOR()
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "gcd_cuda", [&] {
    gpu_kernel(
        iter,
        [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
          return calc_gcd(lhs, rhs);
        });
  });
#else
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "gcd_cuda", [&] {
    jitted_gpu_kernel<
        /*name=*/gcd_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/2>(iter, gcd_string);
  });
#endif // !AT_USE_JITERATOR()
}
34 
// Name passed to jitted_gpu_kernel as its `name` template argument.
// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char lcm_name[] = "lcm";

// Elementwise least-common-multiple kernel for integral dtypes.
//
// Dispatches on iter.common_dtype(). When AT_USE_JITERATOR() is false, the
// result is computed inline from calc_gcd: 0 when the gcd is 0, otherwise
// ::abs(a / g * b). Otherwise the work is handed to jitted_gpu_kernel together
// with lcm_string (defined outside this file), using scalar_t as both the
// return and the common dtype.
void lcm_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
#if !AT_USE_JITERATOR()
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "lcm_cuda", [&] {
    gpu_kernel(
        iter,
        [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
          const scalar_t divisor = calc_gcd(lhs, rhs);
          // A zero gcd would make the division below divide by zero.
          if (divisor == 0) {
            return 0;
          }
          // Divide before multiplying, exactly as the expression was written.
          return ::abs(lhs / divisor * rhs);
        });
  });
#else
  AT_DISPATCH_INTEGRAL_TYPES(dtype, "lcm_cuda", [&] {
    jitted_gpu_kernel<
        /*name=*/lcm_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/2>(iter, lcm_string);
  });
#endif // !AT_USE_JITERATOR()
}
54 
// Register the two kernels defined in this file with gcd_stub and lcm_stub
// (the stubs themselves are declared outside this file).
55 REGISTER_DISPATCH(gcd_stub, &gcd_kernel_cuda);
56 REGISTER_DISPATCH(lcm_stub, &lcm_kernel_cuda);
57 
58 } // namespace at::native
59