1 #include <ATen/cuda/CUDAGeneratorImpl.h>
2 #include <ATen/cuda/CUDAGraphsUtils.cuh>
3 #include <ATen/Utils.h>
4
5 #include <curand.h>
6 #include <curand_kernel.h>
7 #include <curand_philox4x32_x.h>
8
namespace {

// See note [Algorithm of randperm]
//
// Resolves ties among keys that compare equal on their masked bits.
// A maximal run of adjacent positions whose masked keys are equal is an
// "island"; the thread whose index is the first position of an island
// shuffles the matching `data` entries in place with a Fisher-Yates pass.
//
// NOTE(review): assumes `keys` was already sorted on its masked bits (so equal
// masked keys are adjacent) and that `keys` and `data` both hold `n`
// elements -- confirm against the caller.
//
// Launch layout: 1-D grid of 1-D blocks; thread `tid` inspects position `tid`,
// so at least n - 1 threads are needed in total. Uses no shared memory.
//
// Randomness: each island-leading thread draws from Philox subsequence `tid`,
// starting at the offset carried in `philox_args`, and consumes
// island_size - 1 values (at most n - 1).
template<typename T, typename scalar_t>
__global__ void randperm_handle_duplicate_keys_kernel(const T *keys, scalar_t *data, T mask, int64_t n, at::PhiloxCudaState philox_args) {
  // 64-bit index: n may exceed INT_MAX, and blockIdx.x * blockDim.x would
  // otherwise be evaluated in 32-bit unsigned arithmetic.
  const int64_t tid = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;

  // Only the thread sitting at the start of an island does any work.
  if (tid >= n - 1) return;  // out of range, or last element (cannot start an island)
  const auto key = keys[tid] & mask;
  if (key != (keys[tid + 1] & mask)) return;  // not in an island
  if (tid != 0 && key == (keys[tid - 1] & mask)) return;  // inside, but not at the start of, an island

  // Measure the island: count consecutive positions sharing the same masked key.
  int64_t island_size = 0;
  do { island_size++; }
  while (tid + island_size < n && (keys[tid + island_size] & mask) == key);

  // Fisher-Yates shuffle of data[tid, tid + island_size).
  scalar_t *island = data + tid;
  auto seeds = at::cuda::philox::unpack(philox_args);
  curandStatePhilox4_32_10_t state;
  curand_init(std::get<0>(seeds), tid, std::get<1>(seeds), &state);
  for (int64_t i = island_size - 1; i > 0; i--) {
    // NOTE(review): `% (i + 1)` carries a slight modulo bias whenever i + 1
    // does not divide 2^32; kept as-is so generated permutations stay
    // reproducible for a given seed/offset.
    const int64_t r = static_cast<int64_t>(curand(&state) % static_cast<uint64_t>(i + 1));
    if (i != r) {
      scalar_t tmp = island[i];
      island[i] = island[r];
      island[r] = tmp;
    }
  }
}

// See note [Algorithm of randperm]
//
// Host launcher for randperm_handle_duplicate_keys_kernel. Requests a Philox
// offset increment of `n` from `gen_` (or the default CUDA generator), then
// launches the kernel on the current CUDA stream with 512-thread blocks.
//
// `bits` selects how many low bits of each key take part in the comparison
// (presumably the number of bits the caller sorted on -- confirm). Values
// >= 64 select every bit of the key; values <= 0 select none.
template<typename T, typename scalar_t>
void randperm_handle_duplicate_keys(T *keys, scalar_t *data, int bits, int64_t n, std::optional<at::Generator> &gen_) {
  auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(gen_, at::cuda::detail::getDefaultCUDAGenerator());
  // Each island-leading thread draws at most n - 1 values from its
  // subsequence, so an increment of n covers every thread's consumption.
  int64_t counter_offset = n;
  at::PhiloxCudaState rng_engine_inputs;
  {
    // See Note [Acquire lock when using random generators]
    std::lock_guard<std::mutex> lock(gen->mutex_);
    rng_engine_inputs = gen->philox_cuda_state(counter_offset);
  }

  // An island needs at least two elements, so there is nothing to do for
  // n < 2. Returning only after the RNG reservation keeps generator state
  // advancing exactly as before, and it avoids launching a zero-block grid
  // when n == 0 (an invalid launch configuration).
  if (n < 2) return;

  // Build the mask in 64-bit unsigned arithmetic: `1UL << bits` is undefined
  // for bits >= 64, and for bits >= 32 where unsigned long is 32 bits wide.
  uint64_t mask_bits;
  if (bits <= 0) {
    mask_bits = 0;
  } else if (bits >= 64) {
    mask_bits = ~uint64_t{0};
  } else {
    mask_bits = (uint64_t{1} << bits) - 1;
  }
  T mask = static_cast<T>(mask_bits);

  constexpr int64_t kThreadsPerBlock = 512;
  const int64_t num_blocks = (n + kThreadsPerBlock - 1) / kThreadsPerBlock;
  randperm_handle_duplicate_keys_kernel<<<num_blocks, kThreadsPerBlock, 0, at::cuda::getCurrentCUDAStream()>>>(
    keys, data, mask, n, rng_engine_inputs);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

}  // namespace
59