#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/TypeSafeSignMath.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>

#include <type_traits>

namespace at::native {
namespace binary_internal {

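// Elementwise truncated division: the quotient is rounded toward zero.
// Registered below as the CUDA implementation of div_trunc_stub, which backs
// at::div with rounding_mode="trunc".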
void div_trunc_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (isIntegralType(dtype, /*includeBool*/ false)) {
    // Integer division in C++ truncates toward zero, so plain a / b already
    // gives the trunc-rounded quotient for integral dtypes.
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() {
      gpu_kernel_with_scalars(
          iter,
          [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a / b; });
    });
  } else if (iter.is_cpu_scalar(2)) {
    // optimization for floating-point types: if the second operand is a CPU
    // scalar, compute a * reciprocal(b). Note that this may lose one bit of
    // precision compared to computing the division.
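    // (The reciprocal and the multiply each round once, so the product can
    // differ from a single correctly rounded division in the last bit.)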
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
          using accscalar_t = at::acc_type<scalar_t, true>;
          auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
          iter.remove_operand(2);
          gpu_kernel(iter, [inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return std::trunc(a * inv_b);
          });
        });
  } else {
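    // General floating-point path: compute the division directly and truncate
    // the result toward zero.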
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
          gpu_kernel_with_scalars(
              iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
                return std::trunc(a / b);
              });
        });
  }
}
} // namespace binary_internal

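// Route div_trunc_stub to this kernel for CUDA tensors.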
REGISTER_DISPATCH(div_trunc_stub, &binary_internal::div_trunc_kernel_cuda);

} // namespace at::native