1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/AccumulateType.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/native/BinaryOps.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/TensorIterator.h>
7 #include <c10/cuda/CUDAGuard.h>
8 #include <c10/cuda/CUDAMathCompat.h>
9 #include <c10/util/TypeSafeSignMath.h>
10 #include <ATen/native/cuda/BinaryInternal.h>
11 #include <ATen/native/cuda/JitLoops.cuh>
12 #include <ATen/native/cuda/Loops.cuh>
13
14 #include <type_traits>
15
16 namespace at::native {
17 namespace binary_internal {
18
CONSTEXPR_EXCEPT_WIN_CUDA char div_name[] = "div_kernel";

// Elementwise true (floating-point) division on CUDA: out = a / b.
// Dispatch strategy, based on the iterator's common dtype:
//   * complex<Half>: JIT-compile the kernel via the jiterator when available,
//     otherwise fall back to an opmath functor kernel;
//   * everything else: regular dispatch, with a reciprocal-multiply fast path
//     when the divisor is a CPU scalar.
void div_true_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();

  if (dtype == kComplexHalf) {
    using scalar_t = c10::complex<at::Half>;
#if AT_USE_JITERATOR()
    // Kernel source is stringified and JIT-compiled; the entry point name
    // must match div_name above.
    static const auto div_string = jiterator_stringify(
        template <typename T> T div_kernel(T a, T b) { return a / b; });
    opmath_jitted_gpu_kernel_with_scalars<div_name, scalar_t, scalar_t>(
        iter, div_string);
#else
    using opmath_t = at::opmath_type<scalar_t>;
    opmath_gpu_kernel_with_scalars<scalar_t>(iter, DivFunctor<opmath_t>());
#endif
    return;
  }

  if (!iter.is_cpu_scalar(2)) {
    // General case: division computed on device, with GPU-scalar handling.
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kHalf, kBFloat16, dtype, "div_true_cuda", [&]() {
          DivFunctor<scalar_t> div_functor;
          gpu_kernel_with_scalars(iter, div_functor);
        });
  } else {
    // Optimization for floating-point types: when the second operand is a
    // CPU scalar, compute a * reciprocal(b) instead. Note that this may
    // lose one bit of precision compared to computing the division.
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        kHalf, kBFloat16, dtype, "div_true_cuda", [&]() {
          using opmath_t = at::opmath_type<scalar_t>;
          auto b_reciprocal = opmath_t(1.0) / iter.scalar_value<opmath_t>(2);
          // Drop the scalar operand; the remaining unary kernel multiplies
          // by the precomputed reciprocal.
          iter.remove_operand(2);
          gpu_kernel(
              iter,
              BUnaryFunctor<scalar_t, scalar_t, scalar_t, MulFunctor<opmath_t>>(
                  MulFunctor<opmath_t>(), b_reciprocal));
        });
  }
}
57 } // namespace binary_internal
58
// Register the CUDA implementation as the backend for the true-division
// dispatch stub.
REGISTER_DISPATCH(div_true_stub, &binary_internal::div_true_kernel_cuda);
60
61 } // namespace at::native
62