xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Randperm.cuh (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cuda/CUDAGeneratorImpl.h>
2 #include <ATen/cuda/CUDAGraphsUtils.cuh>
3 #include <ATen/Utils.h>
4 
5 #include <curand.h>
6 #include <curand_kernel.h>
7 #include <curand_philox4x32_x.h>
8 
namespace {

// See note [Algorithm of randperm]
//
// Re-shuffles runs ("islands") of consecutive elements whose keys compare
// equal under `mask`. `keys` is only read; `data` is permuted in place.
// Because only *contiguous* equal-key runs are detected, this assumes `keys`
// is already ordered on the masked bits with `data` in matching order
// (presumably produced by the caller's sort -- confirm at call site).
//
// Launch layout: 1-D grid of 1-D blocks providing at least n threads, one per
// index. A thread does work only if its index starts an island (its masked
// key equals the next one and differs from the previous one). That thread
// performs a serial Fisher-Yates shuffle of the island's `data` entries using
// a Philox stream whose subsequence is its own index. No shared memory is
// used.
//
// Indices are 64-bit so that lengths above INT_MAX are honoured instead of
// being truncated (the host launcher passes an int64_t length).
template<typename T, typename scalar_t>
__global__ void randperm_handle_duplicate_keys_kernel(
    const T *keys, scalar_t *data, T mask, int64_t n, at::PhiloxCudaState philox_args) {
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  // Find the beginning of islands; every other thread exits immediately.
  if (tid >= n - 1) return;  // out of range (an island needs a successor)
  const T key = keys[tid] & mask;
  if (key != (keys[tid + 1] & mask)) return;              // not in an island
  if (tid != 0 && key == (keys[tid - 1] & mask)) return;  // not the beginning of an island

  // Measure the island; it holds at least tid and tid + 1.
  int64_t island_size = 2;
  while (tid + island_size < n && (keys[tid + island_size] & mask) == key) {
    ++island_size;
  }

  // Fisher-Yates shuffle of data[tid, tid + island_size).
  scalar_t *island = data + tid;
  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
  for (int64_t i = island_size - 1; i > 0; --i) {
    const int64_t r = curand(&state) % (i + 1);
    if (r != i) {
      scalar_t tmp = island[i];
      island[i] = island[r];
      island[r] = tmp;
    }
  }
}

// See note [Algorithm of randperm]
//
// Host-side launcher for randperm_handle_duplicate_keys_kernel: re-shuffles
// the entries of `data` whose keys collide on their low `bits` bits.
//
// Reserves `n` Philox counter values from `gen_` (or the default CUDA
// generator when `gen_` is empty) while holding the generator's mutex. It then
// enqueues the kernel on the current CUDA stream without synchronizing. Throws
// (via C10_CUDA_KERNEL_LAUNCH_CHECK) if the launch itself fails.
template<typename T, typename scalar_t>
void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional<at::Generator> &gen_) {
  // Nothing can be duplicated in an empty range. Returning here also avoids a
  // zero-sized grid, which CUDA rejects as an invalid launch configuration.
  if (n <= 0) {
    return;
  }

  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
  int64_t counter_offset = n;
  at::PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  // Mask selecting the low `bits` bits of each key. The work is done in
  // unsigned long long, which is guaranteed to be at least 64 bits wide
  // (unsigned long is only 32 bits on LLP64 platforms such as Windows).
  // bits >= 64 is special-cased because `1ULL << 64` is undefined behaviour.
  const unsigned long long mask_bits =
      bits >= 64 ? ~0ULL : (bits <= 0 ? 0ULL : (1ULL << bits) - 1);
  const T mask = static_cast<T>(mask_bits);

  constexpr int kThreadsPerBlock = 512;
  const auto num_blocks =
      static_cast<unsigned int>((n + kThreadsPerBlock - 1) / kThreadsPerBlock);
  randperm_handle_duplicate_keys_kernel<<<num_blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
    keys, data, mask, n, rng_engine_inputs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

}  // namespace
58 }
59