xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/BinaryBitwiseOpsKernels.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/BinaryOps.h>
7 
8 // NOTE: CUDA on Windows requires that the enclosing function
9 // of a __device__ lambda not have internal linkage.
10 
11 namespace at::native {
12 
// Device-side binary functor computing the AND of two operands.
//
// The primary template applies the built-in `&` operator to two values of
// type scalar_t and returns the result as scalar_t. The `bool`
// specialization below expresses the same truth table without `&`.
template <typename scalar_t>
struct BitwiseAndFunctor {
  __device__ __forceinline__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    // Copy-initialisation performs the same implicit conversion back to
    // scalar_t that a direct `return lhs & rhs;` would.
    const scalar_t masked = lhs & rhs;
    return masked;
  }
};

// Boolean AND: the result is true only when both operands are true.
template <>
struct BitwiseAndFunctor<bool> {
  __device__ __forceinline__ bool operator()(bool lhs, bool rhs) const {
    if (!lhs) {
      return false;
    }
    return rhs;
  }
};
26 
// CUDA implementation of bitwise AND over a TensorIterator.
//
// Dispatches on iter.dtype() (integral types plus kBool, per the dispatch
// macro used) and hands a BitwiseAndFunctor for the selected scalar_t to
// opmath_symmetric_gpu_kernel_with_scalars.
// NOTE(review): the helper presumably also covers the case where one operand
// is a scalar, as its name suggests -- confirm against Loops.cuh.
void bitwise_and_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_and_cuda", [&] {
    using functor_t = BitwiseAndFunctor<scalar_t>;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, functor_t{});
  });
}
33 
// Device-side binary functor computing the OR of two operands.
//
// The primary template applies the built-in `|` operator to two values of
// type scalar_t and returns the result as scalar_t. The `bool`
// specialization below expresses the same truth table without `|`.
template <typename scalar_t>
struct BitwiseOrFunctor {
  __device__ __forceinline__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    // Copy-initialisation performs the same implicit conversion back to
    // scalar_t that a direct `return lhs | rhs;` would.
    const scalar_t merged = lhs | rhs;
    return merged;
  }
};

// Boolean OR: the result is true when at least one operand is true.
template <>
struct BitwiseOrFunctor<bool> {
  __device__ __forceinline__ bool operator()(bool lhs, bool rhs) const {
    if (lhs) {
      return true;
    }
    return rhs;
  }
};
47 
// CUDA implementation of bitwise OR over a TensorIterator.
//
// Dispatches on iter.dtype() (integral types plus kBool, per the dispatch
// macro used) and hands a BitwiseOrFunctor for the selected scalar_t to
// opmath_symmetric_gpu_kernel_with_scalars.
// NOTE(review): the helper presumably also covers the case where one operand
// is a scalar, as its name suggests -- confirm against Loops.cuh.
void bitwise_or_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_or_cuda", [&] {
    using functor_t = BitwiseOrFunctor<scalar_t>;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, functor_t{});
  });
}
54 
// Device-side binary functor computing the XOR of two operands.
//
// The primary template applies the built-in `^` operator to two values of
// type scalar_t and returns the result as scalar_t. The `bool`
// specialization below expresses the same truth table without `^`.
template <typename scalar_t>
struct BitwiseXorFunctor {
  __device__ __forceinline__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    // Copy-initialisation performs the same implicit conversion back to
    // scalar_t that a direct `return lhs ^ rhs;` would.
    const scalar_t toggled = lhs ^ rhs;
    return toggled;
  }
};

// Boolean XOR: the result is true exactly when the two operands differ.
template <>
struct BitwiseXorFunctor<bool> {
  __device__ __forceinline__ bool operator()(bool lhs, bool rhs) const {
    // When lhs is true the result is the negation of rhs; otherwise it is rhs.
    return lhs ? !rhs : rhs;
  }
};
68 
// CUDA implementation of bitwise XOR over a TensorIterator.
//
// Dispatches on iter.dtype() (integral types plus kBool, per the dispatch
// macro used) and hands a BitwiseXorFunctor for the selected scalar_t to
// opmath_symmetric_gpu_kernel_with_scalars.
// NOTE(review): the helper presumably also covers the case where one operand
// is a scalar, as its name suggests -- confirm against Loops.cuh.
void bitwise_xor_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, iter.dtype(), "bitwise_xor_cuda", [&] {
    using functor_t = BitwiseXorFunctor<scalar_t>;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, functor_t{});
  });
}
75 
// Bind each bitwise dispatch stub to the CUDA kernel defined in this file.
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel_cuda);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
79 
80 
81 } // namespace at::native
82