#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>


// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
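// This is why compare_scalar_kernel below, which wraps the functor in a
// GPU_LAMBDA, is defined outside the anonymous namespace.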

namespace at::native { namespace {

enum class OpType {GE, GT, LE, LT};

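// Elementwise comparison functor. op_ is uniform across a launch, so every
// thread takes the same branch and the if/else chain causes no divergence.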
template<typename scalar_t>
struct CompareFunctor{
  constexpr CompareFunctor(OpType op): op_(op) {};
  OpType op_;
  __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const {
    if (op_ == OpType::GE) {
      return a >= b;
    } else if (op_ == OpType::GT) {
      return a > b;
    } else if (op_ == OpType::LE) {
      return a <= b;
    } else { //LT
      return a < b;
    }
  }
};

// Reflects the comparison operator, so reflect(op)(a, b) == op(b, a)
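// e.g. reflect(OpType::GT) == OpType::LT, since (a > b) == (b < a).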
OpType reflect(OpType x) {
  switch (x) {
    case OpType::GE: return OpType::LE;
    case OpType::GT: return OpType::LT;
    case OpType::LE: return OpType::GE;
    case OpType::LT: return OpType::GT;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid OpType");
}

}  // namespace (anonymous)

template <typename scalar_t>
void compare_scalar_kernel(TensorIteratorBase &iter, OpType op, scalar_t rhs) {
  CompareFunctor<scalar_t> f(op);
  gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t lhs) -> bool {
    return f(lhs, rhs);
  });
}

template <typename scalar_t>
void compare_kernel_impl(TensorIteratorBase &iter, OpType op) {
  // If either input is a cpu scalar, perform the equivalent comparison
  // where the scalar is on the right hand side. This saves us from
  // generating two otherwise identical kernels with mirrored
  // arguments.
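  // For example, if operand 1 is a CPU scalar s and operand 2 is a CUDA
  // tensor t, op(s, t) is computed as reflect(op)(t, s), e.g. s < t
  // becomes t > s.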
  if (iter.is_cpu_scalar(1)) {
    const scalar_t lhs = iter.scalar_value<scalar_t>(1);
    iter.remove_operand(1);
    const DeviceGuard device_guard(iter.device(1));
    compare_scalar_kernel(iter, reflect(op), lhs);
  } else if (iter.is_cpu_scalar(2)) {
    const scalar_t rhs = iter.scalar_value<scalar_t>(2);
    iter.remove_operand(2);
    compare_scalar_kernel(iter, op, rhs);
  } else {
    CompareFunctor<scalar_t> f(op);
    gpu_kernel(iter, f);
  }
}

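// AT_DISPATCH_ALL_TYPES_AND3 instantiates the lambda below for every standard
// dtype plus Half, BFloat16, and Bool, binding the selected type to scalar_t.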
C10_NOINLINE void compare_kernel_with_scalars(TensorIteratorBase &iter, OpType op) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "compare_cuda", [&]() {
    compare_kernel_impl<scalar_t>(iter, op);
  });
}


void ge_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::GE);
}

void gt_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::GT);
}

void le_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::LE);
}

void lt_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::LT);
}

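// Register the CUDA implementations with the comparison dispatch stubs
// (ge_stub, gt_stub, le_stub, lt_stub declared in ATen/native/BinaryOps.h).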
REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda);
REGISTER_DISPATCH(le_stub, &le_kernel_cuda);
REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);

} // namespace at::native