// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/AccumulateType.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/BinaryOps.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/TensorIterator.h>
7 #include <c10/cuda/CUDAGuard.h>
8 #include <c10/cuda/CUDAMathCompat.h>
9 #include <c10/util/TypeSafeSignMath.h>
10 #include <ATen/native/cuda/BinaryInternal.h>
11 #include <ATen/native/cuda/JitLoops.cuh>
12 #include <ATen/native/cuda/Loops.cuh>
13 
14 #include <type_traits>
15 
namespace at::native {
namespace binary_internal {

// Kernel name handed to the jiterator for the runtime-compiled
// complex<Half> division kernel.
CONSTEXPR_EXCEPT_WIN_CUDA char div_name[] = "div_kernel";

// True (floating-point) division kernel for CUDA tensor iterators.
//
// Dispatch strategy, in order:
//   1. complex<Half>: use a jiterator-compiled kernel when AT_USE_JITERATOR()
//      is enabled, otherwise fall back to DivFunctor computing in the
//      opmath type.
//   2. Divisor is a CPU scalar: hoist it out of the iterator and multiply by
//      its reciprocal instead of dividing (may cost one bit of precision —
//      see comment below).
//   3. Otherwise: plain element-wise division via DivFunctor.
void div_true_kernel_cuda(TensorIteratorBase& iter) {
  const auto common_dtype = iter.common_dtype();

  if (common_dtype == kComplexHalf) {
    using scalar_t = c10::complex<at::Half>;
#if AT_USE_JITERATOR()
    static const auto div_string = jiterator_stringify(
        template <typename T> T div_kernel(T a, T b) { return a / b; });
    opmath_jitted_gpu_kernel_with_scalars<div_name, scalar_t, scalar_t>(
        iter, div_string);
#else
    using opmath_t = at::opmath_type<scalar_t>;
    opmath_gpu_kernel_with_scalars<scalar_t>(iter, DivFunctor<opmath_t>());
#endif
    return;
  }

  if (iter.is_cpu_scalar(2)) {
    // Optimization for floating-point types: when the second operand is a
    // CPU scalar, compute a * reciprocal(b) instead of a / b. Note that this
    // may lose one bit of precision compared to computing the division.
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() {
          using opmath_t = at::opmath_type<scalar_t>;
          const auto reciprocal =
              opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
          iter.remove_operand(2);
          gpu_kernel(
              iter,
              BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
                  MulFunctor<opmath_t>(), reciprocal));
        });
    return;
  }

  // General case: both operands live on the device.
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      kHalf, kBFloat16, common_dtype, "div_true_cuda", [&]() {
        gpu_kernel_with_scalars(iter, DivFunctor<scalar_t>());
      });
}

} // namespace binary_internal

REGISTER_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_cuda);

} // namespace at::native
62