#pragma once
#include <ATen/core/TensorBase.h>
#include <ATen/ceil_div.h>
#include <ATen/NumericUtils.h>
#include <c10/macros/Macros.h>
#include <stdlib.h>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>

namespace at {
namespace native {

// Is this questionable namespace pollution?
#if defined(USE_ROCM)
constexpr int MAX_BLOCK_SIZE = 256;

#else
constexpr int MAX_BLOCK_SIZE = 1024;
#endif

// Maximum size per grid dimension that we assume (compute capability >= 2.0)
constexpr int64_t MAX_GRID_SIZE = 65535LL;

// Decomposes a linear tile count into a 3D grid, spilling into grid.y and
// grid.z once a single dimension would exceed MAX_GRID_SIZE. Returns false
// if gridTiles cannot be represented by a 3D grid at all.
inline bool getGridFromTiles(int64_t gridTiles, dim3& grid) {
  if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) {
    return false;
  }

  int64_t gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
  int64_t gridY = 1;
  int64_t gridZ = 1;

  if (gridTiles > MAX_GRID_SIZE) {
    gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE);
    gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;

    if (gridTiles > MAX_GRID_SIZE) {
      gridTiles = ceil_div(gridTiles, MAX_GRID_SIZE);
      gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles;
    }
  }

  grid = dim3(gridX, gridY, gridZ);
  return true;
}
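
// Illustrative use (an assumption about the call site; not part of this
// header): one tile is typically mapped per slice, and the caller bails out
// when the slice count cannot be addressed by a 3D grid.
//
//   dim3 grid;
//   if (!getGridFromTiles(num_slices, grid)) {
//     // too many slices to address with a 3D grid; caller must give up
//   }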

// Comparison functors used for sorting. When handleNaN is true, a NaN
// compares as greater than any non-NaN value, so NaNs land at the end of an
// ascending (LTOp) sort and at the front of a descending (GTOp) sort.
template <typename scalar_t, bool handleNaN = false>
struct GTOp {
  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
    return (handleNaN && at::_isnan(lhs) && !at::_isnan(rhs)) || (lhs > rhs);
  }
};

template <typename scalar_t, bool handleNaN = false>
struct LTOp {
  __device__ bool operator()(const scalar_t& lhs, const scalar_t& rhs) const {
    return (handleNaN && at::_isnan(rhs) && !at::_isnan(lhs)) || (lhs < rhs);
  }
};

// Flattens the 3D block index into a single linear block id.
template <typename index_t>
__device__ __forceinline__ index_t getLinearBlockId() {
  return blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x +
      blockIdx.x;
}
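
// Illustrative pairing with getGridFromTiles (assumed kernel-side pattern,
// not defined in this header): each block recovers the slice it owns from
// its linear id and exits if the grid was rounded up past num_slices.
//
//   index_t slice = getLinearBlockId<index_t>();
//   if (slice >= num_slices) {
//     return;
//   }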

// For slice sorting in Thrust; extracts a slice index from a linear
// index and uses that for comparison
struct SliceComp {
  SliceComp(int64_t size) : sliceSize(size) {}

  __device__ bool operator()(const int64_t& a, const int64_t& b) const {
    // Since the slices are guaranteed to be innermost, the slice (segment)
    // index is recovered by plain int64_t division.
    int64_t segA = a / sliceSize;
    int64_t segB = b / sliceSize;
    return segA < segB;
  }

  const int64_t sliceSize;
};

// For sorting in Thrust; extracts a within-slice index from a linear index
struct GlobalIndexToPerSliceIndex {
  GlobalIndexToPerSliceIndex(int64_t size) : sliceSize(size) {}

  __device__ inline void operator()(int64_t& v) const {
    v = v % sliceSize;
  }

  const int64_t sliceSize;
};
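
// Illustrative sketch of the segmented-sort pattern these functors support
// (an assumption about the call site; the actual sort lives elsewhere).
// `vals` is the flat buffer of n = num_slices * slice_size elements and
// `iota` holds the linear indices 0..n-1:
//
//   thrust::stable_sort_by_key(policy, vals, vals + n, iota,
//                              LTOp<scalar_t, true>());          // sort by value
//   thrust::stable_sort_by_key(policy, iota, iota + n, vals,
//                              SliceComp(slice_size));           // regroup by slice
//   thrust::for_each(policy, iota, iota + n,
//                    GlobalIndexToPerSliceIndex(slice_size));    // per-slice indices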

// Returns 2^(ceil(lg(n))) from Stanford bit twiddling hacks
inline uint64_t nextHighestPowerOf2(uint64_t n) {
  n--;
  n |= n >> 1;
  n |= n >> 2;
  n |= n >> 4;
  n |= n >> 8;
  n |= n >> 16;
#ifndef _MSC_VER
  n |= n >> 32;
#endif
  n++;

  return n;
}
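
// For example, nextHighestPowerOf2(1000) == 1024 and
// nextHighestPowerOf2(1024) == 1024.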

// WARNING: This function assumes input tensors are contiguous
template <typename scalar_t, typename index_t, typename Launcher>
void run_launcher(
    const TensorBase &values,
    const TensorBase &indices,
    const TensorBase &self,
    int64_t dim,
    Launcher l) {
  auto self_info = cuda::detail::getTensorInfo<const scalar_t, index_t>(self);
  auto values_info = cuda::detail::getTensorInfo<scalar_t, index_t>(values);
  auto indices_info = cuda::detail::getTensorInfo<int64_t, index_t>(indices);

  int64_t slice_size = self.size(dim);
  /* We use these structures solely to find the offset to */
  /* each slice we are operating on */
  self_info.reduceDim(dim);
  values_info.reduceDim(dim);
  indices_info.reduceDim(dim);

  /* Collapse all other dims */
  int collapse_self_dim = self_info.collapseDims(dim);
  int collapse_values_dim = values_info.collapseDims(dim);
  int collapse_indices_dim = indices_info.collapseDims(dim);

  int64_t num_slices = 1;
  for (int i = 0; i < self_info.dims; ++i) {
    num_slices *= self_info.sizes[i];
  }

  /* The collapsed dim count is used as a template parameter to calculate */
  /* indices. We only specialize it if all three tensors collapse to the */
  /* same number of dims; otherwise, we use -1, which is the specialization */
  /* parameter for arbitrary dimensions */
  int all_dims = self_info.dims;
  if (values_info.dims != all_dims || indices_info.dims != all_dims) {
    all_dims = -1;
  }

  if (all_dims == 1) {
    l.template launch<scalar_t, index_t, 1>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  } else if (all_dims == 2) {
    l.template launch<scalar_t, index_t, 2>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  } else if (all_dims == 3) {
    l.template launch<scalar_t, index_t, 3>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  } else {
    l.template launch<scalar_t, index_t, -1>(
        values_info,
        collapse_values_dim,
        indices_info,
        collapse_indices_dim,
        self_info,
        collapse_self_dim,
        num_slices,
        slice_size);
  }
}
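
// Illustrative shape of a Launcher (hypothetical example; concrete launchers
// live in the .cu files that include this header). run_launcher only requires
// a templated launch<scalar_t, index_t, Dims>() member taking the arguments
// shown above:
//
//   struct ExampleLauncher {
//     template <typename scalar_t, typename index_t, int Dims>
//     void launch(
//         cuda::detail::TensorInfo<scalar_t, index_t> values_info,
//         int collapse_values_dim,
//         cuda::detail::TensorInfo<int64_t, index_t> indices_info,
//         int collapse_indices_dim,
//         cuda::detail::TensorInfo<const scalar_t, index_t> self_info,
//         int collapse_self_dim,
//         int64_t num_slices,
//         int64_t slice_size) {
//       dim3 grid;
//       TORCH_INTERNAL_ASSERT(getGridFromTiles(num_slices, grid));
//       // pick a block size (e.g. nextHighestPowerOf2(slice_size), capped at
//       // MAX_BLOCK_SIZE) and launch the actual sorting/selection kernel here
//     }
//   };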

} // namespace native
} // namespace at