#pragma once
#include <ATen/Dispatch.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/sparse/Macros.h>
#include <ATen/native/SparseTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif

#ifdef GPUCC
#define NAME "compressed_index_invariance_checks_cuda"
#else
#define NAME "compressed_index_invariance_checks_cpu"
#endif

#define INVARIANT_CHECK_FUNC_API static INLINE FUNCAPI void

namespace at::native {

namespace {

// NOTE: all the checks but the very last one are designed
// to work with vectors.
// To enable vectorization one would need to write a conversion
// Vec -> bool and make kernel launchers call into vectorized
// execution paths.

// All the invariants are described in
// https://pearu.github.io/bsr_tensor_invariants.html
// NOTE: in the code we also use `cidx/idx` to refer to
// `compressed_indices/plain_indices` respectively.
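//
// As a concrete illustration (values chosen for this comment only): a 3x4
// CSR matrix with nnz == 4 could have
//   crow_indices (cidx) = [0, 2, 2, 4]
//   col_indices  (idx)  = [0, 3, 1, 2]
// which satisfies cidx[0] == 0 (5.1), cidx[-1] == nnz (5.2),
// 0 <= cidx[i] - cidx[i - 1] <= ncols (5.3), 0 <= idx < ncols (5.4/5.5), and
// each row slice of idx ([0, 3] and [1, 2]) is sorted with distinct values (5.6).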

INVARIANT_CHECK_FUNC_API
_assert(const bool cond, const char* const message) {
#ifdef GPUCC
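  // `message` is a non-null string literal, so `cond && message` has the same
  // truth value as `cond`; keeping it inside the asserted expression makes the
  // message text appear in the device-side assert output.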
  CUDA_KERNEL_ASSERT(cond && message);
#else
  TORCH_CHECK(cond, message);
#endif
}

enum class CDimName : bool { CRow, CCol };

// Invariant 5.1
// compressed_index[..., 0] == 0.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_first_cidx_is_zero(
    const index_t& cidx,
    const index_t& zero) {
  const bool invariant = cidx == zero;
  if (cdim_name == CDimName::CRow) {
    _assert(invariant, "`crow_indices[..., 0] == 0` is not satisfied.");
  } else {
    _assert(invariant, "`ccol_indices[..., 0] == 0` is not satisfied.");
  }
}

// Invariant 5.2
// compressed_index[..., -1] == nnz.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_last_cidx_is_nnz(
    const index_t& cidx,
    const index_t& nnz) {
  const bool invariant = cidx == nnz;
  if (cdim_name == CDimName::CRow) {
    _assert(invariant, "`crow_indices[..., -1] == nnz` is not satisfied.");
  } else {
    _assert(invariant, "`ccol_indices[..., -1] == nnz` is not satisfied.");
  }
}

// Invariant 5.3
// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_cidx_nondecreasing_locally_bounded_sequence(
    const index_t& cidx,
    const index_t& cidx_next,
    const index_t& zero,
    const index_t& dim) {
  const auto s_cidx = cidx_next - cidx;
  const bool invariant = zero <= s_cidx && s_cidx <= dim;
  if (cdim_name == CDimName::CRow) {
    _assert(
        invariant,
        "`0 <= crow_indices[..., 1:] - crow_indices[..., :-1] <= ncols` is not satisfied.");
  } else {
    _assert(
        invariant,
        "`0 <= ccol_indices[..., 1:] - ccol_indices[..., :-1] <= nrows` is not satisfied.");
  }
}

// Invariants 5.4 and 5.5
// 0 <= plain_index < plain_dim.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_idx_bounds(
    const index_t& idx,
    const index_t& zero,
    const index_t& dim) {
  const bool invariant = zero <= idx && idx < dim;
  if (cdim_name == CDimName::CRow) {
    _assert(invariant, "`0 <= col_indices < ncols` is not satisfied.");
  } else {
    _assert(invariant, "`0 <= row_indices < nrows` is not satisfied.");
  }
}

// Invariant 5.6
// plain_indices[..., compressed_indices[..., i - 1]:compressed_indices[..., i]]
// for all i = 1, ..., compressed_dim
// are sorted and have distinct values along the last dimension.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_idx_sorted_distinct_vals_slices_with_cidx(
    const index_t* RESTRICT ptr_idx_batch,
    const index_t cidx,
    const index_t cidx_next) {
  // Note that ptr_idx_batch = &idx[batch_idx] and is contiguous.
  const auto* RESTRICT slice_begin = ptr_idx_batch + cidx;
  const auto* RESTRICT slice_end = ptr_idx_batch + cidx_next;
  for (auto* RESTRICT curr = slice_begin; (slice_begin < slice_end) && (curr + 1 < slice_end); ++curr) {
    const auto invariant = *curr < *(curr + 1);
    if (cdim_name == CDimName::CRow) {
      _assert(
          invariant,
          "`col_indices[..., crow_indices[..., i - 1]:crow_indices[..., i]] "
          "for all i = 1, ..., nrows "
          "are sorted and distinct along the last dimension values` "
          "is not satisfied.");
    } else {
      _assert(
          invariant,
          "`row_indices[..., ccol_indices[..., i - 1]:ccol_indices[..., i]] "
          "for all i = 1, ..., ncols "
          "are sorted and distinct along the last dimension values` "
          "is not satisfied.");
    }
  }
}

static inline int64_t indexCount(IntArrayRef sizes) {
  int64_t res = 1;
  for (const auto& s : sizes) {
    res *= s;
  }
  return res;
}
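// For example, indexCount({2, 3}) == 6 and indexCount({}) == 1; below it is
// used to count the batch entries enumerated by `batch_idx`.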

template <typename func_t, typename vec_func_t>
struct EmptyVecKernel {
  static void launch(
      TensorIteratorBase& iter,
      const func_t& f,
      const vec_func_t& vec_f) {}
};

template <typename scalar_t>
using DummyVec = scalar_t;

template <
    template <typename func_t>
    class kernel_t,
    template <typename func_t, typename vec_func_t>
    class vec_kernel_t>
struct KernelLauncher {
  template <typename func_t, typename vec_func_t>
  static void launch(
      TensorIteratorBase& iter,
      const func_t& f,
      const vec_func_t& vec_f) {
    vec_kernel_t<func_t, vec_func_t>::launch(iter, f, vec_f);
  }

  template <typename func_t>
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    kernel_t<func_t>::launch(iter, f);
  }
};
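// The `kernel_t`/`vec_kernel_t` launchers are supplied by the backend that
// includes this header. As an illustrative sketch only (the struct name and
// the use of ATen's cpu_kernel loop utility are assumptions, not part of this
// header), a CPU backend could plug in something along the lines of:
//
//   template <typename func_t>
//   struct CPUKernel {
//     static void launch(TensorIteratorBase& iter, const func_t& f) {
//       cpu_kernel(iter, f);  // elementwise loop over the TensorIterator
//     }
//   };
//
// and then call into KernelLauncher<CPUKernel, EmptyVecKernel>.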

template <
    CDimName cdim_name,
    template <typename func_t>
    class kernel_t,
    template <typename func_t, typename vec_func_t>
    class vec_kernel_t = EmptyVecKernel,
    template <typename scalar_t> class Vec = DummyVec,
    size_t static_shape_max_len = 0>
void _validate_compressed_sparse_indices_kernel(
    const Tensor& cidx,
    const Tensor& idx,
    const int64_t cdim,
    const int64_t dim,
    const int64_t nnz) {
  if (cdim_name == CDimName::CRow) {
    TORCH_CHECK(
        cidx.size(-1) == cdim + 1,
        "crow_indices have wrong shape: ",
        "crow_indices.shape[-1] = ",
        cidx.size(-1),
        " is not equal to ",
        "nrows + 1 = ",
        cdim + 1);
    TORCH_CHECK(
        idx.size(-1) == nnz,
        "col_indices have wrong shape: ",
        "col_indices.shape[-1] = ",
        idx.size(-1),
        " is not equal to ",
        "nnz = ",
        nnz);
  } else {
    TORCH_CHECK(
        cidx.size(-1) == cdim + 1,
        "ccol_indices have wrong shape: ",
        "ccol_indices.shape[-1] = ",
        cidx.size(-1),
        " is not equal to ",
        "ncols + 1 = ",
        cdim + 1);
    TORCH_CHECK(
        idx.size(-1) == nnz,
        "row_indices have wrong shape: ",
        "row_indices.shape[-1] = ",
        idx.size(-1),
        " is not equal to ",
        "nnz = ",
        nnz);
  }

  using KernelLauncher = KernelLauncher<kernel_t, vec_kernel_t>;

  // TensorIterator requires an output, and the check lambdas may not return
  // void, so they write into (an expanded view of) this dummy tensor.
  const auto dummy = at::empty({1}, cidx.options());

  // Catch integer overflow from large dimensions. Otherwise, the
  // invariant checks may fail with bogus exceptions or succeed with
  // false-positive results when int64_t typed dimensions are cast to
  // an index dtype that corresponds to a smaller integer type such as
  // int32_t.
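  // For example (illustrative numbers only): with 32-bit indices, a dimension
  // of 2^31 would wrap to a negative value under static_cast<index_t>, so the
  // round-trip comparison below fails and a clear error is raised instead of
  // validating against a bogus bound.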
  {
    AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [cdim, dim, nnz]() {
      if (cdim_name == CDimName::CRow) {
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(dim)) == dim,
                    sizeof(index_t) * 8, "-bit integer overflow in column dimension = ", dim);
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(cdim)) == cdim,
                    sizeof(index_t) * 8, "-bit integer overflow in row dimension = ", cdim);
      } else {
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(dim)) == dim,
                    sizeof(index_t) * 8, "-bit integer overflow in row dimension = ", dim);
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(cdim)) == cdim,
                    sizeof(index_t) * 8, "-bit integer overflow in column dimension = ", cdim);
      }
      TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(nnz)) == nnz,
                  sizeof(index_t) * 8, "-bit integer overflow in nnz = ", nnz);
    });
  }

  // Invariants 5.4 and 5.5
  {
    auto iter = TensorIteratorConfig()
                    .set_check_mem_overlap(false)
                    .add_owned_output(dummy.expand_as(idx))
                    .add_input(idx)
                    .build();

    AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, dim]() {
      const auto zero = index_t{0};
      KernelLauncher::launch(iter, [zero, dim] FUNCAPI(index_t idx) -> index_t {
        _check_idx_bounds<cdim_name, index_t>(idx, zero, dim);
        return 0;
      });
    });
  }

  // Invariants 5.1, 5.2, 5.3, 5.6
  {
    const auto cidx_first = cidx.slice(-1, 0, 1);
    const auto cidx_last = cidx.slice(-1, cdim, cdim + 1);

    const auto cidx_curr = cidx.slice(-1, 0, cdim);
    const auto cidx_next = cidx.slice(-1, 1, cdim + 1);

    const auto batch_dims = cidx.sizes().slice(0, cidx.dim() - 1);
    const auto batch_count = indexCount(batch_dims);
    const auto batch_idx =
        at::arange(batch_count, cidx.options()).view(batch_dims).unsqueeze_(-1);

    const auto idx_ndims = idx.dim();

    const auto idx_geometry_holder = at::sparse::TensorGeometryHolder<static_shape_max_len>(idx);
    const auto idx_sizes = std::get<0>(*idx_geometry_holder);
    const auto idx_strides = std::get<1>(*idx_geometry_holder);

    auto iter = TensorIteratorConfig()
                    .set_check_mem_overlap(false)
                    .add_owned_output(dummy.expand_as(cidx_curr))
                    .add_input(cidx_first)
                    .add_input(cidx_last)
                    .add_input(cidx_curr)
                    .add_input(cidx_next)
                    .add_input(batch_idx)
                    .build();

    AT_DISPATCH_INDEX_TYPES(
        idx.scalar_type(),
        NAME,
        [&iter, &idx, dim, nnz, idx_ndims, &idx_sizes, &idx_strides]() {
          const auto* RESTRICT ptr_idx = idx.const_data_ptr<index_t>();
          const auto zero = index_t{0};
          KernelLauncher::launch(
              iter,
              [zero, dim, nnz, idx_ndims, idx_sizes, idx_strides, ptr_idx] FUNCAPI(
                  index_t cidx_first,
                  index_t cidx_last,
                  index_t cidx_curr,
                  index_t cidx_next,
                  index_t batch_idx) -> index_t {
                // Invariant 5.1
                _check_first_cidx_is_zero<cdim_name, index_t>(cidx_first, zero);
                // Invariant 5.2
                _check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
                // Invariant 5.3
                _check_cidx_nondecreasing_locally_bounded_sequence<
                    cdim_name,
                    index_t>(cidx_curr, cidx_next, zero, dim);
                // Invariant 5.6
                // NOTE: the implementation below is sync-less, but,
                // unfortunately, work is not guaranteed to be well-balanced
                // between different threads.
                // Note: 5.6 should not be tested when nnz == 0. Fortunately,
                // the code below is a no-op when nnz == 0.
                int64_t idx_offset = 0;
                // assuming idx contiguity per batch:
                int64_t tmp = batch_idx * nnz;
                // `nnz == idx_sizes[idx_ndims - 1]` is checked above as `nnz == idx.size(-1)`
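                // The loop below is a mixed-radix decomposition of the linear
                // index `batch_idx * nnz` over idx_sizes (last dim first),
                // dotted with idx_strides; it yields the memory offset of the
                // first element of this batch's plain-index slice. For a
                // contiguous idx this reduces to batch_idx * nnz.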
                for (int i = idx_ndims - 1;
                     i >= 0 && nnz > 0;  // break early when nnz==0
                     i--) {
                  int64_t div = tmp / idx_sizes[i];
                  idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i];
                  tmp = div;
                }
                const auto* RESTRICT ptr_idx_batch = ptr_idx + idx_offset;
                _check_idx_sorted_distinct_vals_slices_with_cidx<
                    cdim_name,
                    index_t>(ptr_idx_batch, cidx_curr, cidx_next);
                return 0;
              });
        });
  }
}

template <
    template <typename func_t>
    class kernel_t,
    template <typename func_t, typename vec_func_t>
    class vec_kernel_t = EmptyVecKernel,
    template <typename scalar_t> class Vec = DummyVec>
void validate_compressed_sparse_indices_kernel(
    const bool is_crow,
    const Tensor& cidx,
    const Tensor& idx,
    const int64_t cdim,
    const int64_t dim,
    const int64_t nnz) {
  constexpr size_t idx_max_ndims = 8; // up to 7-dim batch.
  const size_t idx_ndims = static_cast<size_t>(idx.dim());

  if (is_crow) {
    if (idx_ndims <= idx_max_ndims) {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CRow,
          kernel_t,
          vec_kernel_t,
          Vec,
          idx_max_ndims>(cidx, idx, cdim, dim, nnz);
    } else {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CRow,
          kernel_t,
          vec_kernel_t,
          Vec>(cidx, idx, cdim, dim, nnz);
    }
  } else {
    if (idx_ndims <= idx_max_ndims) {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CCol,
          kernel_t,
          vec_kernel_t,
          Vec,
          idx_max_ndims>(cidx, idx, cdim, dim, nnz);
    } else {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CCol,
          kernel_t,
          vec_kernel_t,
          Vec>(cidx, idx, cdim, dim, nnz);
    }
  }
}
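// Illustrative usage sketch (the cdim/dim mapping follows the shape checks
// above; `CPUKernel` refers to the hypothetical launcher sketched near
// KernelLauncher): for CSR indices a backend would call
//
//   validate_compressed_sparse_indices_kernel<CPUKernel>(
//       /*is_crow=*/true, crow_indices, col_indices,
//       /*cdim=*/nrows, /*dim=*/ncols, /*nnz=*/nnz);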

} // namespace

} // namespace at::native