// DON'T include this except from Binary*.cu files. It should not leak into
// headers.
#pragma once
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/TypeSafeSignMath.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>

#include <type_traits>

namespace at {
namespace native {
namespace binary_internal {

// Elementwise division functor, invoked once per output element by the
// TensorIterator GPU loops from Loops.cuh.
template <typename scalar_t>
struct DivFunctor {
  __device__ scalar_t operator()(scalar_t a, scalar_t b) const {
    return a / b;
  }
};

// Elementwise multiplication functor.
template <typename T>
struct MulFunctor {
  __device__ T operator()(T a, T b) const {
    return a * b;
  }
};

// Workaround for the error: '*' in boolean context, suggest '&&' instead
// [-Werror=int-in-bool-context]
template <>
struct MulFunctor<bool> {
  __device__ bool operator()(bool a, bool b) const {
    return a && b;
  }
};
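
// Illustrative sketch (not part of this header): one way a Binary*.cu file
// could plug these functors into a TensorIterator loop using the helpers
// from Loops.cuh. The dispatch macro and dtype set below are assumptions
// for the example; the real Binary*.cu kernels cover more dtypes and may
// take jitted paths via JitLoops.cuh.
//
//   void mul_kernel_cuda(TensorIteratorBase& iter) {
//     AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "mul_cuda", [&]() {
//       // gpu_kernel_with_scalars also handles the Tensor-Scalar case.
//       gpu_kernel_with_scalars(iter, MulFunctor<scalar_t>());
//     });
//   }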

// Implemented in the Binary*.cu files that include this header.
void div_true_kernel_cuda(TensorIteratorBase& iter);
void div_trunc_kernel_cuda(TensorIteratorBase& iter);
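
// In the defining .cu file these kernels are typically hooked up to the
// dispatch stubs declared in BinaryOps.h, e.g. (shown here only as an
// illustration of the usual pattern):
//
//   REGISTER_DISPATCH(div_true_stub, &div_true_kernel_cuda);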
} // namespace binary_internal
} // namespace native
} // namespace at