#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {

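// All three logical kernels below follow the same pattern: dispatch on the
// iterator's common dtype; complex inputs go through the jiterator when it
// is available, with a precompiled kernel as the fallback, while every other
// dtype takes the opmath_symmetric_gpu_kernel_with_scalars path.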
CONSTEXPR_EXCEPT_WIN_CUDA char logical_and_name[] = "logical_and_kernel";
void logical_and_kernel_cuda(TensorIterator& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
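    // jiterator_stringify captures the kernel body as a source string;
    // jitted_gpu_kernel compiles it at runtime (via NVRTC) on first use,
    // keeping the complex specializations out of the ahead-of-time build.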
    static const auto logical_and_string = jiterator_stringify(
        template <typename T>
        bool logical_and_kernel(T a, T b) {
          return a && b;
        }
    ); // logical_and_string
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/ logical_and_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, logical_and_string);
    });
#else
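    // Fallback when the jiterator is disabled: the complex specializations
    // are compiled ahead of time instead.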
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
            return a && b;
          });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16,
        dtype, "logical_and_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
            return a && b;
          });
    });
  }
}

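// logical_or mirrors logical_and with `||`. Note that its non-jiterator
// complex fallback uses plain gpu_kernel_with_scalars rather than the
// opmath variant used elsewhere in this file.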
CONSTEXPR_EXCEPT_WIN_CUDA char logical_or_name[] = "logical_or_kernel";
void logical_or_kernel_cuda(TensorIterator& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto logical_or_string = jiterator_stringify(
        template <typename T>
        bool logical_or_kernel(T a, T b) {
          return a || b;
        }
    ); // logical_or_string
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/ logical_or_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, logical_or_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return a || b;
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16,
        dtype, "logical_or_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
            return a || b;
          });
    });
  }
}

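// logical_xor has no short-circuiting operator to lean on, so both operands
// are converted to bool and compared for inequality.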
CONSTEXPR_EXCEPT_WIN_CUDA char logical_xor_name[] = "logical_xor_kernel";
void logical_xor_kernel_cuda(TensorIterator& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    static const auto logical_xor_string = jiterator_stringify(
        template <typename T>
        bool logical_xor_kernel(T a, T b) {
          return bool(a) != bool(b);
        }
    ); // logical_xor_string
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/ logical_xor_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 2>(iter, logical_xor_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_cuda", [&]() {
      gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
        return bool(a) != bool(b);
      });
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(kHalf, kBool, ScalarType::BFloat16,
        dtype, "logical_xor_cuda", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
            return bool(a) != bool(b);
          });
    });
  }
}

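// Register the CUDA implementations with the per-device dispatch stubs
// declared in ATen/native/BinaryOps.h. An illustrative (not exact) call
// path for CUDA tensors, assuming the usual TensorIterator plumbing:
//
//   Tensor out = at::logical_and(a, b);
//   // -> builds a TensorIterator over (a, b, out)
//   // -> logical_and_stub(kCUDA, iter)
//   // -> logical_and_kernel_cuda(iter)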
REGISTER_DISPATCH(logical_and_stub, &logical_and_kernel_cuda);
REGISTER_DISPATCH(logical_or_stub, &logical_or_kernel_cuda);
REGISTER_DISPATCH(logical_xor_stub, &logical_xor_kernel_cuda);

} // namespace at::native