#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/cuda/cub.h>
#include <ATen/native/cuda/Randperm.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/randperm_native.h>
#endif

#include <limits>

namespace at::native {

// [Algorithm of randperm]
//
// randperm is implemented by sorting an arange tensor of size n with randomly
// generated keys. When the random keys are all distinct, every permutation is
// equally likely.
//
// However, there is a pitfall here:
// For better performance, these n random keys are generated independently,
// and there is no effort to make sure they are distinct at generation time.
// When two keys are identical, a stable sorting algorithm will not permute the
// corresponding values. As a result, (0, 1) will appear more often than (1, 0).
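// Concretely: if n = 2 and both draws happen to produce the same key k, a stable
// sort of [(k, 0), (k, 1)] always yields (0, 1) and never (1, 0), so the identity
// permutation would be over-represented without further correction.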
//
// To overcome this pitfall, we first carefully choose the number of bits in these keys,
// so that the probability of having duplicate keys stays under a threshold. Let q be the
// target probability that all keys are distinct; then it can be shown[1] that
// the number of bits required is: ceil(log2(n - (6 n^2 + 1) / (12 log(q))))
//
// After the sort, we launch a separate kernel that additionally shuffles any islands
// of values whose keys matched. The algorithm of this kernel is as follows:
// Each thread reads its key and the keys of its neighbors to tell whether it is part of an island.
// For each island, the first thread in the island sees a key match at index i+1 but not at index i-1.
// This thread considers itself the "island leader". The island leader then reads more indices to
// the right to figure out how big the island is. Most likely, the island will be very small,
// just a few values. The island leader then draws that many random numbers, uses them to
// shuffle the values within the island with a serial Fisher-Yates pass, and writes them out.
//
// Reference
// [1] https://osf.io/af2hy/

// The kernels are templated on an opaque, self-aligned type of the correct
// size to avoid redundant kernels for different types of the same size.
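// For example, kFloat and kInt both have sizeof(scalar_t) == 4, so they share the
// single OpaqueType<4> instantiation of the sort kernels below.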
namespace {
template <int N> struct alignas(N) OpaqueType { char data[N]; };
}

Tensor& randperm_out_cuda(int64_t n, std::optional<Generator> generator, Tensor& result) {
  TORCH_CHECK(n >= 0, "n must be non-negative, got ", n);

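  // Every index in [0, n) must be exactly representable in result's dtype;
  // e.g. kHalf only represents consecutive integers up to 2^11, so overly large n is rejected.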
  check_supported_max_int_with_precision(n, result);

  result.resize_({n});

  auto range = at::arange(n, result.options());

  // shuffled_data points to the underlying data of the output tensor if the tensor
  // is contiguous; otherwise it points to a new tensor.
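  // The radix sort below reads and writes through raw pointers, so it needs a flat,
  // contiguous buffer; a non-contiguous result is staged in `shuffled` and copied back at the end.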
  Tensor shuffled;
  void *shuffled_data;
  if (result.is_contiguous()) {
    shuffled_data = result.data_ptr();
  } else {
    shuffled = at::empty(n, result.options());
    shuffled_data = shuffled.data_ptr();
  }

  auto opt = TensorOptions().device(result.device());

  // See note [Algorithm of randperm]
  const double log_threshold_12 = std::log(0.9) * 12;
  double nd = static_cast<double>(n);

  int bits = std::min(64,
    static_cast<int>(std::ceil(std::log2(nd - (6 * nd * nd + 1) / log_threshold_12))));
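  // For a sense of scale (q = 0.9, so log_threshold_12 ≈ -1.2643): n = 10^6 gives
  // ceil(log2(1e6 + (6e12 + 1) / 1.2643)) = 43 bits, taking the 64-bit key path
  // below; 32-bit keys suffice only up to roughly n ≈ 3e4.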

  if (n == 0) {
    return result;
  } else if (bits <= 32) {
    // Asserting that the generator and the result are on the same device is
    // delegated to the random_ call below.
    auto keys = at::empty(result.sizes(), opt.dtype(kInt)).random_(
      std::numeric_limits<int>::min(), std::numeric_limits<int>::max(), generator);
    auto keys_tmp = at::empty_like(keys);
    auto keys_out = keys_tmp.mutable_data_ptr<int>();
    AT_DISPATCH_ALL_TYPES_AND(kHalf, result.scalar_type(), "randperm_out_cuda", [&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      auto shuffled_data_ = reinterpret_cast<dtype*>(shuffled_data);
      auto* range_data = reinterpret_cast<const dtype*>(range.const_data_ptr());
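      // Ascending sort of (key, value) pairs over only the low `bits` bits of each
      // key (descending = false, begin_bit = 0, end_bit = bits).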
      at::cuda::cub::radix_sort_pairs<int, dtype>(
        keys.const_data_ptr<int>(), keys_out,
        range_data, shuffled_data_,
        n, false, 0, bits);

      randperm_handle_duplicate_keys(keys_out, shuffled_data_, bits, n, generator);
    });
  } else {
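    // Same algorithm as the 32-bit branch above, but with 64-bit keys so the
    // duplicate-key probability stays below the threshold for large n.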
    auto keys = at::empty(result.sizes(), opt.dtype(kLong)).random_(
      std::numeric_limits<int64_t>::min(), std::numeric_limits<int64_t>::max(), generator);
    auto keys_tmp = at::empty_like(keys);
    auto keys_out = keys_tmp.mutable_data_ptr<int64_t>();
    AT_DISPATCH_ALL_TYPES_AND(kHalf, result.scalar_type(), "randperm_out_cuda", [&] {
      using dtype = OpaqueType<sizeof(scalar_t)>;
      auto shuffled_data_ = reinterpret_cast<dtype*>(shuffled_data);
      auto* range_data = reinterpret_cast<const dtype*>(range.const_data_ptr());
      at::cuda::cub::radix_sort_pairs<int64_t, dtype>(
        keys.const_data_ptr<int64_t>(), keys_out,
        range_data, shuffled_data_,
        n, false, 0, bits);

      randperm_handle_duplicate_keys(keys_out, shuffled_data_, bits, n, generator);
    });
  }

  if (!result.is_contiguous()) {
    result.copy_(shuffled);
  }

  return result;
}
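
// Usage sketch (illustrative, not part of this file): this kernel is reached via
// the ordinary randperm operator whenever the output lives on a CUDA device, e.g.
//   Tensor p = at::randperm(8, at::TensorOptions().device(at::kCUDA).dtype(at::kLong));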

} // namespace at::native