#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/Repeat.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/repeat_interleave_native.h>
#endif

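// Scatter kernel: each warp handles one entry of repeat_ptr and fills that
// entry's slice of the output with its index. cumsum_ptr holds the inclusive
// prefix sum of the repeat counts, so entry i's slice ends at cumsum_ptr[i].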
template <typename index_t>
__global__ static void compute_cuda_kernel(
    const index_t* repeat_ptr,
    const int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
  // The last cumsum entry is the total number of repeats, which must match
  // the allocated output length.
  CUDA_KERNEL_ASSERT(result_size == cumsum_ptr[size - 1]);
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  // One warp per input element: `stride` counts warps across the whole grid.
  int64_t stride = (blockDim.x * gridDim.x) / C10_WARP_SIZE;
  int warp_id = idx / C10_WARP_SIZE;
  int tid_in_warp = idx % C10_WARP_SIZE;
  for (int64_t i = warp_id; i < size; i += stride) {
    // Element i owns the output segment [cumsum_ptr[i] - repeat, cumsum_ptr[i]).
    int64_t end = cumsum_ptr[i];
    index_t repeat = repeat_ptr[i];
    CUDA_KERNEL_ASSERT(repeat >= 0);
    int64_t start = end - repeat;
    // Lanes of the warp fill the segment cooperatively.
    for (int64_t j = start + tid_in_warp; j < end; j += C10_WARP_SIZE) {
      result_ptr[j] = i;
    }
  }
}

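// Host-side launcher: picks a launch configuration and runs the kernel on
// the current CUDA stream.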
template <typename index_t>
static void compute_cuda(
    const index_t* repeat_ptr,
    const int64_t* cumsum_ptr,
    index_t* result_ptr,
    int64_t size,
    int64_t result_size) {
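  // 512 threads per block, one warp per input element, with the grid capped
  // at 2048 blocks; the kernel's grid-stride loop covers any remaining
  // elements.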
  int64_t block = 512;
  int64_t warps_per_block = block / at::cuda::warp_size();
  int64_t grid =
      std::min<int64_t>((size + warps_per_block - 1) / warps_per_block, 2048L);

  compute_cuda_kernel<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
      repeat_ptr, cumsum_ptr, result_ptr, size, result_size);
  // Surface any launch-time error immediately.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

namespace at::native {

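// Builds the index map for repeat_interleave: for repeats = [1, 2, 3] the
// result is [0, 1, 1, 2, 2, 2]. Dispatches on the index type (int32/int64)
// and defers cumsum computation and output allocation to
// repeat_interleave_common (ATen/native/Repeat.h).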
Tensor repeat_interleave_cuda(
    const Tensor& repeat,
    std::optional<int64_t> output_size) {
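  // When the caller already knows output_size, the common helper can skip
  // reading the cumsum total back from the device, avoiding a sync.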
  Tensor output;
  AT_DISPATCH_INDEX_TYPES(
      repeat.scalar_type(), "repeat_interleave_cuda", [&]() {
        output = repeat_interleave_common<index_t, compute_cuda<index_t>>(
            repeat, output_size);
      });
  return output;
}

} // namespace at::native