1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/BinaryOps.h>
7
8 // NOTE: CUDA on Windows requires that the enclosing function
9 // of a __device__ lambda not have internal linkage.
10
11 namespace at::native {
12
13
lshift_kernel_cuda(TensorIteratorBase & iter)14 void lshift_kernel_cuda(TensorIteratorBase& iter) {
15 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
16 gpu_kernel_with_scalars(iter,
17 []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
18 constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
19 if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
20 return 0;
21 }
22 return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
23 });
24 });
25 }
26
rshift_kernel_cuda(TensorIteratorBase & iter)27 void rshift_kernel_cuda(TensorIteratorBase& iter) {
28 AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
29 gpu_kernel_with_scalars(iter,
30 []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
31 // right shift value to retain sign bit for signed and no bits for unsigned
32 constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
33 if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
34 return a >> max_shift;
35 }
36 return a >> b;
37 });
38 });
39 }
40
41 REGISTER_DISPATCH(lshift_stub, &lshift_kernel_cuda);
42 REGISTER_DISPATCH(rshift_stub, &rshift_kernel_cuda);
43
44 } // namespace at::native
45