#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>

#include <climits>     // for CHAR_BIT, used by both kernels below
#include <type_traits> // for std::make_signed_t / std::make_unsigned_t / std::is_signed_v

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
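// (For example, declaring these kernels `static` or defining them in an
// anonymous namespace would give them internal linkage and violate this
// requirement.)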

namespace at::native {

void lshift_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_cuda", [&]() {
    gpu_kernel_with_scalars(iter,
      []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // shift on the unsigned bit pattern, and return 0 for negative or
        // too-large shift amounts; a plain C++ shift would be undefined here
        constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
        if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
          return 0;
        }
        return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
    });
  });
}
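
// A minimal host-side sketch of the clamping rule above (illustrative only;
// `lshift_ref` is not part of the kernel, it just mirrors the lambda so the
// intended semantics can be checked at compile time):
namespace {
template <typename T>
constexpr T lshift_ref(T a, T b) {
  constexpr T max_shift = sizeof(T) * CHAR_BIT;
  if ((static_cast<std::make_signed_t<T>>(b) < 0) || (b >= max_shift)) {
    return 0;
  }
  return static_cast<std::make_unsigned_t<T>>(a) << b;
}
static_assert(lshift_ref<signed char>(1, 3) == 8, "in-range shift");
static_assert(lshift_ref<signed char>(1, 8) == 0, "shift >= bit width yields 0");
static_assert(lshift_ref<signed char>(1, -1) == 0, "negative shift yields 0");
} // namespace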
26 
void rshift_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_cuda", [&]() {
    gpu_kernel_with_scalars(iter,
      []GPU_LAMBDA(scalar_t a, scalar_t b) -> scalar_t {
        // for out-of-range shift amounts, shift by max_shift instead: this
        // retains only the sign fill for signed types and no bits for unsigned
        constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
        if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) || (b >= max_shift)) {
          return a >> max_shift;
        }
        return a >> b;
    });
  });
}
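
// A minimal host-side sketch of the saturating rule above (illustrative only;
// `rshift_ref` mirrors the lambda): out-of-range shift amounts produce the
// sign fill for signed types (0 or -1) and 0 for unsigned types.
namespace {
template <typename T>
constexpr T rshift_ref(T a, T b) {
  constexpr T max_shift = sizeof(T) * CHAR_BIT - std::is_signed_v<T>;
  if ((static_cast<std::make_signed_t<T>>(b) < 0) || (b >= max_shift)) {
    return a >> max_shift;
  }
  return a >> b;
}
static_assert(rshift_ref<signed char>(16, 2) == 4, "in-range shift");
static_assert(rshift_ref<signed char>(16, 100) == 0, "oversized shift saturates");
static_assert(rshift_ref<unsigned char>(255, 100) == 0, "unsigned saturates to 0");
// With arithmetic right shift, rshift_ref<signed char>(-16, 100) == -1.
} // namespace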

REGISTER_DISPATCH(lshift_stub, &lshift_kernel_cuda);
REGISTER_DISPATCH(rshift_stub, &rshift_kernel_cuda);

} // namespace at::native