xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Repeat.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/Repeat.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/repeat_interleave_native.h>
#endif

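// How this file implements repeat_interleave on CUDA: given a 1-D tensor
// of non-negative repeat counts and its inclusive prefix sum (cumsum),
// each input element i owns the output slice
// [cumsum[i] - repeat[i], cumsum[i]), and every slot in that slice is
// filled with the value i.
//
// Worked example: repeats = [2, 0, 3] gives cumsum = [2, 2, 5], so the
// result is [0, 0, 2, 2, 2] (element 0 written twice, element 1 zero
// times, element 2 three times).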
template <typename index_t>
__global__ static void compute_cuda_kernel(
    const index_t* repeat_ptr,
    const int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
  // The last cumsum entry is the total number of output elements.
  CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]);
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Grid-stride loop measured in warps: each warp processes one input
  // element per iteration.
  int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
  int warp_id = idx / C10_WARP_SIZE;
  int tid_in_warp = idx % C10_WARP_SIZE;
  for (int64_t i = warp_id; i < size; i += stride) {
    int64_t end = cumsum_ptr[i];
    index_t repeat = repeat_ptr[i];
    CUDA_KERNEL_ASSERT(repeat >= 0);
    int64_t start = end - repeat;
    // Lanes of the warp cooperatively fill result[start, end) with i.
    for (int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) {
      result_ptr[j] = i;
    }
  }
}

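// Launch configuration: 512 threads per block yields
// 512 / warp_size() warps per block; the grid is sized so that, up to a
// cap of 2048 blocks, there is one warp per input element, with the
// kernel's grid-stride loop covering any remainder.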
template <typename index_t>
static void compute_cuda(
    const index_t* repeat_ptr,
    const int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
  int64_t block = 512;
  int64_t warps_per_block = block / at::cuda::warp_size();
  int64_t grid =
      std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);

  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
      repeat_ptr, cumsum_ptr, result_ptr, size, result_size);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

namespace at::native {

Tensor repeat_interleave_cuda(
    const Tensor& repeat,
    std::optional<int64_t> output_size) {
  Tensor output;
  // Dispatch on the index type (int32/int64) of the repeats tensor and
  // hand the shared driver our CUDA fill routine.
  AT_DISPATCH_INDEX_TYPES(
      repeat.scalar_type(), "repeat_interleave_cuda", [&]() {
        output = repeat_interleave_common<index_t, compute_cuda<index_t>>(
            repeat, output_size);
      });
  return output;
}

} // namespace at::native
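// A minimal usage sketch (an illustration, not part of this file): when
// at::repeat_interleave is called with just a 1-D repeats tensor, it is
// expected to route to repeat_interleave_cuda above for CUDA inputs.
// Assuming a CUDA build:
//
//   at::Tensor repeats = at::tensor(
//       {2, 0, 3}, at::dtype(at::kLong).device(at::kCUDA));
//   at::Tensor idx = at::repeat_interleave(repeats);
//   // idx == [0, 0, 2, 2, 2]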