#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Reduce.cuh>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/SharedReduceOps.h>
#include <ATen/native/ReduceOps.h>
#include <ATen/Dispatch.h>

namespace at::native {

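// Reduction kernel for logical AND: dispatches over the iterator's common
// dtype (all standard types plus complex, half, bfloat16, and bool) and
// accumulates into bool with identity `true`.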
void and_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kHalf, kBFloat16, kBool, iter.common_dtype(), "and_cuda", [&]() {
        gpu_reduce_kernel<scalar_t, bool>(
            iter,
            func_wrapper<bool>([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
              return (static_cast<bool>(a) && static_cast<bool>(b));
            }),
            true);
      });
}

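// Reduction kernel for logical OR: same dtype dispatch as above, accumulating
// into bool with identity `false`.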
void or_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      kHalf, kBFloat16, kBool, iter.common_dtype(), "or_cuda", [&]() {
        gpu_reduce_kernel<scalar_t, bool>(
            iter,
            func_wrapper<bool>([] GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
              return (static_cast<bool>(a) || static_cast<bool>(b));
            }),
            false);
      });
}

REGISTER_DISPATCH(and_stub, &and_kernel_cuda);
REGISTER_DISPATCH(or_stub, &or_kernel_cuda);
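// Usage sketch (illustrative only, not part of the kernels themselves): these
// registrations back ATen's boolean reductions, so on a CUDA tensor a caller
// such as Tensor::all() / Tensor::any() is expected to route here through
// and_stub / or_stub. A rough example, assuming a CUDA build and the usual
// ATen tensor API:
//
//   at::Tensor mask  = at::rand({4, 4}, at::kCUDA).gt(0.5);
//   at::Tensor every = mask.all();  // AND-reduction, identity true
//   at::Tensor some  = mask.any();  // OR-reduction, identity false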

} // namespace at::native