#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/TypeSafeSignMath.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>

#include <type_traits>

namespace at::native {
namespace binary_internal {

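// Truncated division: the quotient of a / b rounded toward zero, e.g.
// trunc(-7 / 2) == -3, whereas floor division would give -4.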
void div_trunc_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (isIntegralType(dtype, /*includeBool*/ false)) {
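    // Integer division in C++ already truncates toward zero, so the plain
    // a / b below implements trunc division for integral dtypes.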
    AT_DISPATCH_INTEGRAL_TYPES(dtype, "div_trunc_cuda", [&]() {
      gpu_kernel_with_scalars(
          iter,
          [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t { return a / b; });
    });
  } else if (iter.is_cpu_scalar(2)) {
    // optimization for floating-point types: if the second operand is a CPU
    // scalar, compute a * reciprocal(b). Note that this may lose one bit of
    // precision compared to computing the division.
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
          using accscalar_t = at::acc_type<scalar_t, true>;
          auto inv_b = accscalar_t(1.0) / iter.scalar_value<accscalar_t>(2);
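          // The scalar has been folded into inv_b, so drop operand 2 and run
          // the remaining computation as a unary kernel over operand 1.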
          iter.remove_operand(2);
          gpu_kernel(iter, [inv_b] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return std::trunc(a * inv_b);
          });
        });
  } else {
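    // General floating-point case: divide elementwise on the GPU, then
    // truncate the quotient toward zero.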
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, dtype, "div_trunc_cuda", [&]() {
          gpu_kernel_with_scalars(
              iter, [] GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
                return std::trunc(a / b);
              });
        });
  }
}
} // namespace binary_internal

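// Register this kernel as the CUDA implementation of the truncated-division
// dispatch stub.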
REGISTER_DISPATCH(div_trunc_stub, &binary_internal::div_trunc_kernel_cuda);

} // namespace at::native