#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>


// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native { namespace {

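// Selects which comparison CompareFunctor performs; a single functor
// parameterized at runtime covers all four ordering comparisons.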
enum class OpType {GE, GT, LE, LT};

template<typename scalar_t>
struct CompareFunctor{
  constexpr CompareFunctor(OpType op): op_(op) {};
  OpType op_;
  __device__ __forceinline__ bool operator() (scalar_t a, scalar_t b) const {
    if (op_ == OpType::GE) {
      return a >= b;
    } else if (op_ == OpType::GT) {
      return a > b;
    } else if (op_ == OpType::LE) {
      return a <= b;
    } else { //LT
      return a < b;
    }
  }
};

// Reflects the comparison operator, so reflect(op)(a, b) == op(b, a)
OpType reflect(OpType x) {
  switch (x) {
    case OpType::GE: return OpType::LE;
    case OpType::GT: return OpType::LT;
    case OpType::LE: return OpType::GE;
    case OpType::LT: return OpType::GT;
  }
  TORCH_INTERNAL_ASSERT(false, "Invalid OpType");
}

} // namespace (anonymous)

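// Compares every element of the remaining tensor operand against a single
// scalar rhs. Defined outside the anonymous namespace because it encloses a
// __device__ lambda (see the NOTE on Windows linkage above).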
template <typename scalar_t>
void compare_scalar_kernel(TensorIteratorBase &iter, OpType op, scalar_t rhs) {
  CompareFunctor<scalar_t> f(op);
  gpu_kernel(iter, [=] GPU_LAMBDA (scalar_t lhs) -> bool {
    return f(lhs, rhs);
  });
}

template <typename scalar_t>
void compare_kernel_impl(TensorIteratorBase &iter, OpType op) {
  // If either input is a cpu scalar, perform the equivalent comparison
  // where the scalar is on the right hand side. This saves us from
  // generating two otherwise identical kernels with mirrored
  // arguments.
  if (iter.is_cpu_scalar(1)) {
    const scalar_t lhs = iter.scalar_value<scalar_t>(1);
    iter.remove_operand(1);
    // After removing operand 1, the remaining tensor input shifts down to
    // index 1; guard on its device before launching the kernel.
    const DeviceGuard device_guard(iter.device(1));
    compare_scalar_kernel(iter, reflect(op), lhs);
  } else if (iter.is_cpu_scalar(2)) {
    const scalar_t rhs = iter.scalar_value<scalar_t>(2);
    iter.remove_operand(2);
    compare_scalar_kernel(iter, op, rhs);
  } else {
    CompareFunctor<scalar_t> f(op);
    gpu_kernel(iter, f);
  }
}

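// Dispatches on the iterator's common dtype (all standard types plus Half,
// BFloat16, and Bool) and forwards to the dtype-specific implementation.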
C10_NOINLINE void compare_kernel_with_scalars(TensorIteratorBase &iter, OpType op) {
  AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBFloat16, kBool, iter.common_dtype(), "compare_cuda", [&]() {
    compare_kernel_impl<scalar_t>(iter, op);
  });
}


void ge_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::GE);
}

void gt_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::GT);
}

void le_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::LE);
}

void lt_kernel_cuda(TensorIteratorBase& iter) {
  compare_kernel_with_scalars(iter, OpType::LT);
}

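// Register the CUDA implementations with the shared comparison dispatch stubs
// (declared in BinaryOps.h, included above).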
REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda);
REGISTER_DISPATCH(le_stub, &le_kernel_cuda);
REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);

} // namespace at::native