// aten/src/ATen/native/cuda/BinaryLogicalOpsKernels.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
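// For that reason the kernel entry points below are defined at namespace scope
// with external linkage rather than as static functions or inside an anonymous
// namespace.
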
namespace at::native {

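// Each kernel follows the same pattern: for complex dtypes the comparison is
// JIT-compiled through a jiterator string kernel when AT_USE_JITERATOR() is
// enabled, with a precompiled GPU lambda as the fallback; all other dtypes go
// through AT_DISPATCH_ALL_TYPES_AND3 with a bool-returning lambda.
// opmath_symmetric_gpu_kernel_with_scalars is used where the operation is
// symmetric (a op b == b op a), which should allow a single specialization to
// cover a CPU scalar appearing on either side.
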
CONSTEXPR_EXCEPT_WIN_CUDA char logical_and_name[] = "logical_and_kernel";
void logical_and_kernel_cuda(TensorIterator& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto logical_and_string = jiterator_stringify(
        template <typename T>
        bool logical_and_kernel(T a, T b) {
          return a && b;
        }
    ); // logical_and_string
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_cuda", [&]() {
      jitted_gpu_kernel<
        /*name=*/ logical_and_name,
        /*return_dtype=*/ scalar_t,
        /*common_dtype=*/ scalar_t,
        /*arity=*/ 2>(iter, logical_and_string);
    }); // AT_DISPATCH_COMPLEX_TYPES
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a && b;
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16,
                               dtype, "logical_and_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a && b;
      });
    });
  }
}

CONSTEXPR_EXCEPT_WIN_CUDA char logical_or_name[] = "logical_or_kernel";
void logical_or_kernel_cuda(TensorIterator& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto logical_or_string = jiterator_stringify(
      template <typename T>
      bool logical_or_kernel(T a, T b) {
        return a || b;
      }
    ); // logical_or_string
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_cuda", [&]() {
      jitted_gpu_kernel<
        /*name=*/ logical_or_name,
        /*return_dtype=*/ scalar_t,
        /*common_dtype=*/ scalar_t,
        /*arity=*/ 2>(iter, logical_or_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a || b;
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16,
                               dtype, "logical_or_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a || b;
      });
    });
  }
}

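// logical_xor converts each operand to bool before comparing with != so that
// the result is a truth-value XOR; comparing the raw (possibly complex) values
// with != would test value inequality instead.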
CONSTEXPR_EXCEPT_WIN_CUDA char logical_xor_name[] = "logical_xor_kernel";
void logical_xor_kernel_cuda(TensorIterator& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto logical_xor_string = jiterator_stringify(
        template <typename T>
        bool logical_xor_kernel(T a, T b) {
          return bool(a) != bool(b);
        }
    );
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_cuda", [&]() {
      jitted_gpu_kernel<
        /*name=*/ logical_xor_name,
        /*return_dtype=*/ scalar_t,
        /*common_dtype=*/ scalar_t,
        /*arity=*/ 2>(iter, logical_xor_string);
    }); // AT_DISPATCH_COMPLEX_TYPES
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return bool(a) != bool(b);
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16,
                               dtype, "logical_xor_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return bool(a) != bool(b);
      });
    });
  }
}

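// These registrations make the kernels above the CUDA implementations of the
// logical_{and,or,xor} dispatch stubs. Roughly, a call such as
//   auto a = at::ones({4}, at::kCUDA);
//   auto b = at::zeros({4}, at::kCUDA);
//   auto c = at::logical_and(a, b);  // Bool tensor computed by logical_and_kernel_cuda
// reaches these kernels through the corresponding stub.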
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel_cuda);

} // namespace at::native