#pragma once
#include <ATen/Dispatch.h>
#include <ATen/Tensor.h>
#include <ATen/Utils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/sparse/Macros.h>
#include <ATen/native/SparseTensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/tensor.h>
#endif

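// NAME is the label passed to AT_DISPATCH_INDEX_TYPES below; it records
// whether the checks run on the CUDA or the CPU backend.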
#ifdef GPUCC
#define NAME "compressed_index_invariance_checks_cuda"
#else
#define NAME "compressed_index_invariance_checks_cpu"
#endif

#define INVARIANT_CHECK_FUNC_API static INLINE FUNCAPI void

namespace at::native {

namespace {

// NOTE: all the checks but the very last one are designed
// to work with vectors.
// To enable vectorization one would need to write a conversion
// Vec -> bool and make kernel launchers call into vectorized
// execution paths.

// All the invariants are described in
// https://pearu.github.io/bsr_tensor_invariants.html
// NOTE: in the code we also use `cidx/idx` to refer to
// `compressed_indices/plain_indices` respectively.

INVARIANT_CHECK_FUNC_API
_assert(const bool cond, const char* const message) {
#ifdef GPUCC
  CUDA_KERNEL_ASSERT(cond && message);
#else
  TORCH_CHECK(cond, message);
#endif
}

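// Selects between the row-compressed layouts (CSR/BSR, `crow_indices`) and the
// column-compressed layouts (CSC/BSC, `ccol_indices`). The checks themselves
// are identical for both; only the error messages differ.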
enum class CDimName : bool { CRow, CCol };

// Invariant 5.1
// compressed_index[..., 0] == 0.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_first_cidx_is_zero(
    const index_t& cidx,
    const index_t& zero) {
  const bool invariant = cidx == zero;
  if (cdim_name == CDimName::CRow) {
    _assert(invariant, "`crow_indices[..., 0] == 0` is not satisfied.");
  } else {
    _assert(invariant, "`ccol_indices[..., 0] == 0` is not satisfied.");
  }
}

// Invariant 5.2
// compressed_index[..., -1] == nnz.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_last_cidx_is_nnz(
    const index_t& cidx,
    const index_t& nnz) {
  const bool invariant = cidx == nnz;
  if (cdim_name == CDimName::CRow) {
    _assert(invariant, "`crow_indices[..., -1] == nnz` is not satisfied.");
  } else {
    _assert(invariant, "`ccol_indices[..., -1] == nnz` is not satisfied.");
  }
}

// Invariant 5.3
// 0 <= compressed_indices[..., 1:] - compressed_indices[..., :-1] <= plain_dim.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_cidx_nondecreasing_locally_bounded_sequence(
    const index_t& cidx,
    const index_t& cidx_next,
    const index_t& zero,
    const index_t& dim) {
  const auto s_cidx = cidx_next - cidx;
  const bool invariant = zero <= s_cidx && s_cidx <= dim;
  if (cdim_name == CDimName::CRow) {
    _assert(
        invariant,
        "`0 <= crow_indices[..., 1:] - crow_indices[..., :-1] <= ncols` is not satisfied.");
  } else {
    _assert(
        invariant,
        "`0 <= ccol_indices[..., 1:] - ccol_indices[..., :-1] <= nrows` is not satisfied.");
  }
}

// Invariants 5.4 and 5.5
// 0 <= plain_index < plain_dim.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_idx_bounds(
    const index_t& idx,
    const index_t& zero,
    const index_t& dim) {
  const bool invariant = zero <= idx && idx < dim;
  if (cdim_name == CDimName::CRow) {
    _assert(invariant, "`0 <= col_indices < ncols` is not satisfied.");
  } else {
    _assert(invariant, "`0 <= row_indices < nrows` is not satisfied.");
  }
}

// Invariant 5.6
// plain_indices[..., compressed_indices[..., i - 1]:compressed_indices[..., i]]
// for all i = 1, ..., compressed_dim
// are sorted and distinct along the last dimension values.
template <CDimName cdim_name, typename index_t>
INVARIANT_CHECK_FUNC_API _check_idx_sorted_distinct_vals_slices_with_cidx(
    const index_t* RESTRICT ptr_idx_batch,
    const index_t cidx,
    const index_t cidx_next) {
  // Note that ptr_idx_batch = &idx[batch_idx] and is contiguous.
  const auto* RESTRICT slice_begin = ptr_idx_batch + cidx;
  const auto* RESTRICT slice_end = ptr_idx_batch + cidx_next;
  for (auto* RESTRICT curr = slice_begin;
       (slice_begin < slice_end) && (curr + 1 < slice_end);
       ++curr) {
    const auto invariant = *curr < *(curr + 1);
    if (cdim_name == CDimName::CRow) {
      _assert(
          invariant,
          "`col_indices[..., crow_indices[..., i - 1]:crow_indices[..., i]] "
          "for all i = 1, ..., nrows "
          "are sorted and distinct along the last dimension values` "
          "is not satisfied.");
    } else {
      _assert(
          invariant,
          "`row_indices[..., ccol_indices[..., i - 1]:ccol_indices[..., i]] "
          "for all i = 1, ..., ncols "
          "are sorted and distinct along the last dimension values` "
          "is not satisfied.");
    }
  }
}

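// Number of elements spanned by `sizes`, i.e. the product of all sizes
// (1 for an empty shape). Used below to count batch entries of cidx.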
static inline int64_t indexCount(IntArrayRef sizes) {
  int64_t res = 1;
  for (const auto& s : sizes) {
    res *= s;
  }
  return res;
}

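// Default stand-ins for the vectorized launch path: with these defaults only
// the scalar `kernel_t` launcher does any work (see the NOTE at the top of
// this file on what full vectorization would require).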
template <typename func_t, typename vec_func_t>
struct EmptyVecKernel {
  static void launch(
      TensorIteratorBase& iter,
      const func_t& f,
      const vec_func_t& vec_f) {}
};

template <typename scalar_t>
using DummyVec = scalar_t;

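// Thin wrapper over the backend-provided launchers: the (func, vec_func)
// overload forwards to `vec_kernel_t`, the single-func overload to `kernel_t`.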
template <
    template <typename func_t>
    class kernel_t,
    template <typename func_t, typename vec_func_t>
    class vec_kernel_t>
struct KernelLauncher {
  template <typename func_t, typename vec_func_t>
  static void launch(
      TensorIteratorBase& iter,
      const func_t& f,
      const vec_func_t& vec_f) {
    vec_kernel_t<func_t, vec_func_t>::launch(iter, f, vec_f);
  }

  template <typename func_t>
  static void launch(TensorIteratorBase& iter, const func_t& f) {
    kernel_t<func_t>::launch(iter, f);
  }
};

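// Core validation routine for one layout (`cdim_name`):
//  1. shape checks on `cidx`/`idx` against `cdim`, `dim` and `nnz`,
//  2. overflow checks of `cdim`/`dim`/`nnz` against the index dtype,
//  3. element-wise invariant checks (5.1-5.6) launched via `kernel_t`.
// `static_shape_max_len` is forwarded to at::sparse::TensorGeometryHolder; the
// public wrapper below passes 8 when `idx.dim() <= 8` and 0 otherwise.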
template <
    CDimName cdim_name,
    template <typename func_t>
    class kernel_t,
    template <typename func_t, typename vec_func_t>
    class vec_kernel_t = EmptyVecKernel,
    template <typename scalar_t> class Vec = DummyVec,
    size_t static_shape_max_len = 0>
void _validate_compressed_sparse_indices_kernel(
    const Tensor& cidx,
    const Tensor& idx,
    const int64_t cdim,
    const int64_t dim,
    const int64_t nnz) {
  if (cdim_name == CDimName::CRow) {
    TORCH_CHECK(
        cidx.size(-1) == cdim + 1,
        "crow_indices have wrong shape: ",
        "crow_indices.shape[-1] = ",
        cidx.size(-1),
        " is not equal to ",
        "nrows + 1 = ",
        cdim + 1);
    TORCH_CHECK(
        idx.size(-1) == nnz,
        "col_indices have wrong shape: ",
        "col_indices.shape[-1] = ",
        idx.size(-1),
        " is not equal to ",
        "nnz = ",
        nnz);
  } else {
    TORCH_CHECK(
        cidx.size(-1) == cdim + 1,
        "ccol_indices have wrong shape: ",
        "ccol_indices.shape[-1] = ",
        cidx.size(-1),
        " is not equal to ",
        "ncols + 1 = ",
        cdim + 1);
    TORCH_CHECK(
        idx.size(-1) == nnz,
        "row_indices have wrong shape: ",
        "row_indices.shape[-1] = ",
        idx.size(-1),
        " is not equal to ",
        "nnz = ",
        nnz);
  }

  using KernelLauncher = KernelLauncher<kernel_t, vec_kernel_t>;

  // TensorIterator requires an output and non-void lambdas, so every check
  // below writes into this dummy tensor and returns 0.
  const auto dummy = at::empty({1}, cidx.options());

  // Catch integer overflow from large dimensions. Otherwise, the
  // invariant checks may fail with bogus exceptions or succeed with
  // false-positive results when int64_t-typed dimensions are cast to
  // an index dtype corresponding to a smaller integer type such as
  // int32_t.
  {
    AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [cdim, dim, nnz]() {
      if (cdim_name == CDimName::CRow) {
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(dim)) == dim,
            sizeof(index_t) * 8, "-bit integer overflow in column dimension = ", dim);
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(cdim)) == cdim,
            sizeof(index_t) * 8, "-bit integer overflow in row dimension = ", cdim);
      } else {
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(dim)) == dim,
            sizeof(index_t) * 8, "-bit integer overflow in row dimension = ", dim);
        TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(cdim)) == cdim,
            sizeof(index_t) * 8, "-bit integer overflow in column dimension = ", cdim);
      }
      TORCH_CHECK(static_cast<int64_t>(static_cast<index_t>(nnz)) == nnz,
          sizeof(index_t) * 8, "-bit integer overflow in nnz = ", nnz);
    });
  }

  // Invariants 5.4 and 5.5
  {
    auto iter = TensorIteratorConfig()
                    .set_check_mem_overlap(false)
                    .add_owned_output(dummy.expand_as(idx))
                    .add_input(idx)
                    .build();

    AT_DISPATCH_INDEX_TYPES(idx.scalar_type(), NAME, [&iter, dim]() {
      const auto zero = index_t{0};
      KernelLauncher::launch(iter, [zero, dim] FUNCAPI(index_t idx) -> index_t {
        _check_idx_bounds<cdim_name, index_t>(idx, zero, dim);
        return 0;
      });
    });
  }

  // Invariants 5.1, 5.2, 5.3, 5.6
  {
    const auto cidx_first = cidx.slice(-1, 0, 1);
    const auto cidx_last = cidx.slice(-1, cdim, cdim + 1);

    const auto cidx_curr = cidx.slice(-1, 0, cdim);
    const auto cidx_next = cidx.slice(-1, 1, cdim + 1);

    const auto batch_dims = cidx.sizes().slice(0, cidx.dim() - 1);
    const auto batch_count = indexCount(batch_dims);
    const auto batch_idx =
        at::arange(batch_count, cidx.options()).view(batch_dims).unsqueeze_(-1);

    const auto idx_ndims = idx.dim();

    const auto idx_geometry_holder =
        at::sparse::TensorGeometryHolder<static_shape_max_len>(idx);
    const auto idx_sizes = std::get<0>(*idx_geometry_holder);
    const auto idx_strides = std::get<1>(*idx_geometry_holder);

    auto iter = TensorIteratorConfig()
                    .set_check_mem_overlap(false)
                    .add_owned_output(dummy.expand_as(cidx_curr))
                    .add_input(cidx_first)
                    .add_input(cidx_last)
                    .add_input(cidx_curr)
                    .add_input(cidx_next)
                    .add_input(batch_idx)
                    .build();

    AT_DISPATCH_INDEX_TYPES(
        idx.scalar_type(),
        NAME,
        [&iter, &idx, dim, nnz, idx_ndims, &idx_sizes, &idx_strides]() {
          const auto* RESTRICT ptr_idx = idx.const_data_ptr<index_t>();
          const auto zero = index_t{0};
          KernelLauncher::launch(
              iter,
              [zero, dim, nnz, idx_ndims, idx_sizes, idx_strides, ptr_idx] FUNCAPI(
                  index_t cidx_first,
                  index_t cidx_last,
                  index_t cidx_curr,
                  index_t cidx_next,
                  index_t batch_idx) -> index_t {
                // Invariant 5.1
                _check_first_cidx_is_zero<cdim_name, index_t>(cidx_first, zero);
                // Invariant 5.2
                _check_last_cidx_is_nnz<cdim_name, index_t>(cidx_last, nnz);
                // Invariant 5.3
                _check_cidx_nondecreasing_locally_bounded_sequence<
                    cdim_name,
                    index_t>(cidx_curr, cidx_next, zero, dim);
                // Invariant 5.6
                // NOTE: the implementation below is sync-less, but,
                // unfortunately, work is not guaranteed to be well-balanced
                // between different threads.
                // Note: 5.6 should not be tested when nnz == 0. Fortunately,
                // the code below is a no-op when nnz == 0.
                int64_t idx_offset = 0;
                // assuming idx contiguity per batch:
                int64_t tmp = batch_idx * nnz;
                // `nnz == idx_sizes[idx_ndims - 1]` is checked above as
                // `nnz == idx.size(-1)`
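                // Convert the linear position of this batch's first plain
                // index (batch_idx * nnz) into a memory offset using the
                // sizes and strides of idx.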
                for (int i = idx_ndims - 1;
                     i >= 0 && nnz > 0; // break early when nnz == 0
                     i--) {
                  int64_t div = tmp / idx_sizes[i];
                  idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i];
                  tmp = div;
                }
                const auto* RESTRICT ptr_idx_batch = ptr_idx + idx_offset;
                _check_idx_sorted_distinct_vals_slices_with_cidx<
                    cdim_name,
                    index_t>(ptr_idx_batch, cidx_curr, cidx_next);
                return 0;
              });
        });
  }
}

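// Public entry point: picks the layout (row- vs column-compressed) and, when
// `idx` has at most 8 dims, the statically-shaped geometry fast path.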
template <
    template <typename func_t>
    class kernel_t,
    template <typename func_t, typename vec_func_t>
    class vec_kernel_t = EmptyVecKernel,
    template <typename scalar_t> class Vec = DummyVec>
void validate_compressed_sparse_indices_kernel(
    const bool is_crow,
    const Tensor& cidx,
    const Tensor& idx,
    const int64_t cdim,
    const int64_t dim,
    const int64_t nnz) {
  constexpr size_t idx_max_ndims = 8; // up to 7-dim batch.
  const size_t idx_ndims = static_cast<size_t>(idx.dim());

  if (is_crow) {
    if (idx_ndims <= idx_max_ndims) {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CRow,
          kernel_t,
          vec_kernel_t,
          Vec,
          idx_max_ndims>(cidx, idx, cdim, dim, nnz);
    } else {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CRow,
          kernel_t,
          vec_kernel_t,
          Vec>(cidx, idx, cdim, dim, nnz);
    }
  } else {
    if (idx_ndims <= idx_max_ndims) {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CCol,
          kernel_t,
          vec_kernel_t,
          Vec,
          idx_max_ndims>(cidx, idx, cdim, dim, nnz);
    } else {
      _validate_compressed_sparse_indices_kernel<
          CDimName::CCol,
          kernel_t,
          vec_kernel_t,
          Vec>(cidx, idx, cdim, dim, nnz);
    }
  }
}
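
// A minimal sketch of how a backend might hook into this header. The adapter
// `MyScalarKernel`, the wrapper `validate_compressed_sparse_indices_my_backend`,
// and the element-wise loop they would forward to are illustrative assumptions,
// not part of this file:
//
//   template <typename func_t>
//   struct MyScalarKernel {
//     static void launch(TensorIteratorBase& iter, const func_t& f) {
//       // forward to the backend's element-wise loop over `iter`, applying `f`
//     }
//   };
//
//   void validate_compressed_sparse_indices_my_backend(
//       const bool is_crow,
//       const Tensor& cidx,
//       const Tensor& idx,
//       const int64_t cdim,
//       const int64_t dim,
//       const int64_t nnz) {
//     validate_compressed_sparse_indices_kernel<MyScalarKernel>(
//         is_crow, cidx, idx, cdim, dim, nnz);
//   }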

} // namespace

} // namespace at::native