1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/BinaryOps.h>
7
8 // NOTE: CUDA on Windows requires that the enclosing function
9 // of a __device__ lambda not have internal linkage.
10
11 namespace at::native {
12
// Device-side binary functor computing the bitwise AND of two values.
// In this file it is instantiated for the element types selected by
// AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, ...).
template <typename scalar_t>
struct BitwiseAndFunctor {
  __device__ __forceinline__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    // Operands narrower than int are promoted before `&`; convert the result
    // back to the element type explicitly.
    return static_cast<scalar_t>(lhs & rhs);
  }
};

// bool specialization: AND is expressed as a logical conjunction rather than
// a bit operation.
template <>
struct BitwiseAndFunctor<bool> {
  __device__ __forceinline__ bool operator()(bool lhs, bool rhs) const {
    // true only when both inputs are true.
    return lhs ? rhs : false;
  }
};
26
// Host entry point for bitwise AND on CUDA.
// Selects scalar_t from the iterator's dtype (integral types plus kBool, per
// the dispatch macro used) and forwards a BitwiseAndFunctor<scalar_t> to
// opmath_symmetric_gpu_kernel_with_scalars, which is declared outside this
// file (presumably ATen/native/cuda/Loops.cuh -- confirm).
void bitwise_and_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.dtype();
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, dtype, "bitwise_and_cuda", [&] {
    using and_op_t = BitwiseAndFunctor<scalar_t>;
    and_op_t and_op;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, and_op);
  });
}
33
// Device-side binary functor computing the bitwise OR of two values.
// In this file it is instantiated for the element types selected by
// AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, ...).
template <typename scalar_t>
struct BitwiseOrFunctor {
  __device__ __forceinline__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    // Operands narrower than int are promoted before `|`; convert the result
    // back to the element type explicitly.
    return static_cast<scalar_t>(lhs | rhs);
  }
};

// bool specialization: OR is expressed as a logical disjunction rather than
// a bit operation.
template <>
struct BitwiseOrFunctor<bool> {
  __device__ __forceinline__ bool operator()(bool lhs, bool rhs) const {
    // true as soon as either input is true.
    return lhs ? true : rhs;
  }
};
47
// Host entry point for bitwise OR on CUDA.
// Selects scalar_t from the iterator's dtype (integral types plus kBool, per
// the dispatch macro used) and forwards a BitwiseOrFunctor<scalar_t> to
// opmath_symmetric_gpu_kernel_with_scalars, which is declared outside this
// file (presumably ATen/native/cuda/Loops.cuh -- confirm).
void bitwise_or_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.dtype();
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, dtype, "bitwise_or_cuda", [&] {
    using or_op_t = BitwiseOrFunctor<scalar_t>;
    or_op_t or_op;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, or_op);
  });
}
54
// Device-side binary functor computing the bitwise XOR of two values.
// In this file it is instantiated for the element types selected by
// AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, ...).
template <typename scalar_t>
struct BitwiseXorFunctor {
  __device__ __forceinline__ scalar_t operator()(scalar_t lhs, scalar_t rhs) const {
    // Operands narrower than int are promoted before `^`; convert the result
    // back to the element type explicitly.
    return static_cast<scalar_t>(lhs ^ rhs);
  }
};

// bool specialization: XOR of two booleans is true exactly when they differ.
template <>
struct BitwiseXorFunctor<bool> {
  __device__ __forceinline__ bool operator()(bool lhs, bool rhs) const {
    // When lhs is set the answer is the negation of rhs, otherwise rhs itself.
    return lhs ? !rhs : rhs;
  }
};
68
// Host entry point for bitwise XOR on CUDA.
// Selects scalar_t from the iterator's dtype (integral types plus kBool, per
// the dispatch macro used) and forwards a BitwiseXorFunctor<scalar_t> to
// opmath_symmetric_gpu_kernel_with_scalars, which is declared outside this
// file (presumably ATen/native/cuda/Loops.cuh -- confirm).
void bitwise_xor_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.dtype();
  AT_DISPATCH_INTEGRAL_TYPES_AND(kBool, dtype, "bitwise_xor_cuda", [&] {
    using xor_op_t = BitwiseXorFunctor<scalar_t>;
    xor_op_t xor_op;
    opmath_symmetric_gpu_kernel_with_scalars<scalar_t>(iter, xor_op);
  });
}
75
// Bind each bitwise dispatch stub to the kernel entry point implemented in
// this file. The stubs themselves are not declared here (presumably in
// ATen/native/BinaryOps.h -- confirm).
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel_cuda);
REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel_cuda);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel_cuda);
79
80
81 } // namespace at::native
82