xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Tensor.h>
4 #include <ATen/native/TensorIterator.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/native/sparse/Macros.h>
7 #include <ATen/ExpandUtils.h>
8 #include <ATen/native/SparseTensorUtils.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/arange.h>
15 #include <ATen/ops/empty.h>
16 #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
17 #include <ATen/ops/result_type.h>
18 #endif
19 
20 #ifdef GPUCC
21 #define NAME "sparse_binary_op_intersection_cuda"
22 #else
23 #define NAME "sparse_binary_op_intersection_cpu"
24 #endif
25 
26 namespace at::native {
27 
28 namespace {
29 
30 using at::sparse::get_sparse_impl;
31 
// Binary search for a bound of `value` in the sorted range [first, last).
//
//  * is_lower == true : returns the first position `it` with *it >= value
//                       (same contract as std::lower_bound);
//  * is_lower == false: returns the first position `it` with *it > value
//                       (same contract as std::upper_bound).
// In both cases `last` is returned when no such position exists.
//
// ForwardIt: only legacy random access iterators are supported, because the
// implementation relies on `last - first` and `it += n`.
template<class ForwardIt, class T, bool is_lower = true>
static FUNCAPI INLINE
ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) {
    using diff_t = typename std::iterator_traits<ForwardIt>::difference_type;

    // NOTE: std::distance(first, last) compiles but produces wrong results on CUDA,
    // hence the plain iterator subtraction.
    diff_t remaining = last - first;

    while (remaining > 0) {
      const diff_t half = remaining / 2;

      // Probe the middle of the current window.
      // std::advance(probe, half) is avoided on purpose,
      // although it does work on CUDA unlike std::distance.
      ForwardIt RESTRICT probe = first;
      probe += half;

      // This predicate is the only difference between the two kinds of bound.
      // When it holds, neither *probe nor anything before it can be the bound,
      // so the window shrinks to the elements strictly after probe:
      // [first, first + remaining) becomes
      // [first + half + 1, first + remaining), i.e. half + 1 elements are dropped.
      // Otherwise the bound is at probe or before it and the window becomes
      // [first, first + half).
      bool bound_is_after_probe;
      if (is_lower) {
        bound_is_after_probe = *probe < value;
      } else {
        bound_is_after_probe = value >= *probe;
      }

      if (!bound_is_after_probe) {
        remaining = half;
        continue;
      }

      first = ++probe;
      remaining -= half + 1;
    }
    return first;
}
68 
69 template <template <typename func_t> class kernel_t>
70 struct KernelLauncher {
71   template <typename func_t>
launchKernelLauncher72   static void launch(TensorIteratorBase& iter, const func_t& f) {
73     kernel_t<func_t>::launch(iter, f);
74   }
75 };
76 
make_value_selection_intersection_iter(const Tensor & lhs_values,const Tensor & lhs_select_idx,const Tensor & rhs_values,const Tensor & rhs_select_idx,const Tensor & intersection_counts)77 TensorIterator make_value_selection_intersection_iter(
78     const Tensor& lhs_values,
79     const Tensor& lhs_select_idx,
80     const Tensor& rhs_values,
81     const Tensor& rhs_select_idx,
82     const Tensor& intersection_counts) {
83   const auto res_values_sizes = [&]() -> std::vector<int64_t> {
84     auto sizes = infer_size(
85         // keep nnz dim
86         lhs_values.sizes(),
87         // remove nnz dim for smooth broadcasting
88         rhs_values.sizes().slice(1));
89     // update nnz dim to be the length of an index
90     sizes[0] = lhs_select_idx.numel();
91     return sizes;
92   }();
93   auto res_values = at::empty(res_values_sizes, lhs_values.options());
94 
95   const auto restride_idx = [&res_values](const Tensor& idx) -> Tensor {
96     auto idx_sizes = std::vector<int64_t>(res_values.dim(), 1);
97     auto idx_strides = std::vector<int64_t>(res_values.dim(), 0);
98     idx_sizes[0] = idx.numel();
99     idx_strides[0] = 1;
100     return idx.as_strided(idx_sizes, idx_strides);
101   };
102 
103   const auto restride_values = [&lhs_select_idx](const Tensor& values) -> Tensor {
104     auto values_sizes = at::DimVector(values.sizes());
105     auto values_strides = at::DimVector(values.strides());
106     values_sizes[0] = lhs_select_idx.numel();
107     values_strides[0] = 0;
108     return values.as_strided(values_sizes, values_strides);
109   };
110 
111   auto iter = TensorIteratorConfig()
112     .set_check_mem_overlap(false)
113     .check_all_same_dtype(false)
114     .resize_outputs(false)
115     .add_owned_output(res_values)
116     .add_owned_input(restride_values(lhs_values))
117     .add_owned_input(restride_idx(lhs_select_idx))
118     .add_owned_input(restride_values(rhs_values))
119     .add_owned_input(restride_idx(rhs_select_idx))
120     .add_owned_input(restride_idx(intersection_counts))
121     .build();
122 
123   return iter;
124 }
125 
126 template <
127   template <typename func_t> class kernel_t,
128   typename value_selection_intersection_kernel_t,
129   typename index_t = int64_t,
130   int64_t max_static_len = 0>
131 void _sparse_binary_op_intersection_kernel_impl(
132     Tensor& res,
133     const Tensor& x_,
134     const Tensor& y_,
135     const std::vector<int64_t>& broadcasted_shape,
136     const std::optional<Tensor>& x_hash_opt_ = std::nullopt,
137     const std::optional<Tensor>& y_hash_opt_ = std::nullopt,
138     const bool accumulate_matches = true,
139     const bool distributive_with_sum = true
140 ) {
141   // The common dtype check is relevant when op is done in-place.
142   // This is because binary_of_t produces new values and it could be that
143   // new_values.dtype != res.dtype. In such a case we should error out
144   // as soon as possible to avoid redundant kernel runs.
145   const auto common_dtype = at::result_type(x_, y_);
146   TORCH_CHECK(canCast(common_dtype, res.scalar_type()),
147       "Can't convert result type ", common_dtype,
148       " to output ", res.scalar_type());
149 
150   using KernelLauncher = KernelLauncher<kernel_t>;
151   using OptTensor = std::optional<Tensor>;
152 
153   // If the op and sum are not distributive, coalesce is required.
154   const auto coalesce_if_not_distributive = [distributive_with_sum](const Tensor& t, const OptTensor& t_hash_opt) -> auto {
155     // No need to coalesce in such a case.
156     if (distributive_with_sum) {
157       return std::make_tuple(t, t_hash_opt);
158     } else {
159       // Otherwise coalesce and force hash recompute.
160       return std::make_tuple(t.coalesce(), static_cast<OptTensor>(std::nullopt));
161     }
162   };
163 
164   Tensor x, y;
165   OptTensor x_hash_opt, y_hash_opt;
166   std::tie(x, x_hash_opt) = coalesce_if_not_distributive(x_, x_hash_opt_);
167   std::tie(y, y_hash_opt) = coalesce_if_not_distributive(y_, y_hash_opt_);
168 
169   // Given sparse tensors x and y we decide which one is source, and which one
170   // is probably_coalesced. The indices of both source and probably_coalesced are
171   // hashed and then the hash values of the source's indices are binary-searched
172   // into the hash values of the probably_coalesced's indices.
173   // If probably_coalesce is coalesced, by the property of the hashing method
174   // (see below), the hash values are already sorted and we can avoid any
175   // explicit sorting routines.
176   Tensor probably_coalesced, source;
177   OptTensor probably_coalesced_indices_hash_opt, source_indices_hash_opt;
178   std::tie(probably_coalesced, probably_coalesced_indices_hash_opt, source, source_indices_hash_opt) = [&]() -> auto {
179     // Case 1: either x or y is coalesced.
180     if ((x.is_coalesced() ^ y.is_coalesced())) {
181       return x.is_coalesced()
182         ? std::make_tuple(x, x_hash_opt, y, y_hash_opt)
183         : std::make_tuple(y, y_hash_opt, x, x_hash_opt);
184     }
185     // Case 2: Both x and y are either coalesced or non-coalesced.
186     // If both are coalesced, search into the larger tensor is faster.
187     // Same holds when both are non-coalesced.
188     else {
189       Tensor larger, smaller;
190       OptTensor larger_hash_opt, smaller_hash_opt;
191       std::tie(larger, larger_hash_opt, smaller, smaller_hash_opt) = [&]() -> auto {
192         return x._nnz() >= y._nnz()
193           ? std::make_tuple(x, x_hash_opt, y, y_hash_opt)
194           : std::make_tuple(y, y_hash_opt, x, x_hash_opt);
195       }();
196 
197       // If under a uniform distribution it is likely to hit many elements in larger,
198       // it is best to coalesce it for better performance.
199       const auto larger_sizes = larger.sizes();
200       const auto sparse_dim_numel = std::accumulate(
201           larger_sizes.begin(),
202           larger_sizes.begin() + larger.sparse_dim(),
203           1,
204           std::multiplies<int64_t>());
205       // If nnz > prod(larger.shape[:sparse_dim]), by the pidgeonhole principle,
206       // there is at least one bucket with nnz / prod(larger.shape[:sparse_dim]) elements.
207       // It provides a lower bound for the max count in the intersection.
208       // This condition is very conservative as we do not check whether such an event
209       // actually occurred, although it is very likely under a uniform distribution,
210       // the distribution with the highest uncertainty (maximizes entropy).
211       const auto max_count_lower_bound = larger._nnz() / sparse_dim_numel;
212       constexpr int64_t MAX_COPIES_PER_THREAD = 50;
213       return max_count_lower_bound > MAX_COPIES_PER_THREAD
214         // coalesce invalidates hash values, so force-recompute
215         ? std::make_tuple(larger.coalesce(), static_cast<OptTensor>(std::nullopt), smaller, smaller_hash_opt)
216         : std::make_tuple(larger, larger_hash_opt, smaller, smaller_hash_opt);
217     }
218   }();
219 
220   // The employed hash function maps a d-dim index to a linear offset
221   // into a contiguous memory that is sufficient to fit a dense tensor
222   // of shape broadcasted_shape(x.shape, y.shape), i.e.
223   // idx -> \sum_{i = 0}^d idx[i] * hash_coeffs[i], where
224   // hash_coeffs are the strides of a contiguous tensor of shape
225   // broadcasted_shape(x.shape, y.shape).
226   // Assuming the following order on the dimensions, i.e. the right-most dim is the
227   // fastest-changing dim, and the left-most is the slowest-changing dim,
228   // which is implicit in the definition of hash_coeffs,
229   // it could be shown that the hash function is actually bijective and, hence,
230   // is a perfect hash function (no collisions ever).
231 
232   // Need owning storage in case of the Tensor class.
233   const auto hash_coeffs_storage = [&]() -> auto {
234     const auto broadcasted_sparse_dim_shape = std::vector<int64_t>(
235       broadcasted_shape.begin(),
236       broadcasted_shape.begin() + probably_coalesced.sparse_dim()
237     );
238     auto strides = c10::contiguous_strides(broadcasted_sparse_dim_shape);
239     return at::sparse::TensorGeometryHolder<max_static_len>(strides, strides, probably_coalesced.options());
240   }();
241 
242   const auto hash_coeffs = std::get<0>(*hash_coeffs_storage);
243 
244   const auto nnz_arange = at::arange(
245       std::max(probably_coalesced._nnz(), source._nnz()),
246       source._indices().options());
247   const auto probably_coalesced_nnz_arange = nnz_arange.narrow(-1, 0, probably_coalesced._nnz());
248 
249   // non-const because of gcc-5/clang-5 issues
250   auto sparse_dim = probably_coalesced.sparse_dim();
251 
252   // Apply the hash function to probably_coalesced.indices
253   const auto probably_coalesced_indices_hash = [&]() -> Tensor {
254     // probably_coalesced is coalesced and hash provided? Reuse it!
255     if (probably_coalesced_indices_hash_opt.has_value()) {
256       return (*probably_coalesced_indices_hash_opt).contiguous();
257     }
258 
259     const auto indices = probably_coalesced._indices();
260     // non-const because of gcc-5/clang-5 issues
261     auto indices_dim_stride = indices.stride(0);
262     auto indices_nnz_stride = indices.stride(1);
263 
264     auto hash = at::empty({probably_coalesced._nnz()}, indices.options().dtype(kLong));
265 
266     auto iter = TensorIteratorConfig()
267       .check_all_same_dtype(false)
268       .add_output(hash)
269       .add_input(probably_coalesced_nnz_arange)
270       .build();
271 
272     {
273       const auto* RESTRICT ptr_indices = indices.const_data_ptr<index_t>();
274 
275       KernelLauncher::launch(iter,
276           // NOTE: capture by value required by CUDA
277           [=] FUNCAPI (index_t nnz_idx) -> int64_t {
278           int64_t hash = 0;
279           if (!ptr_indices) {
280             return hash;
281           }
282           const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
283           for (int64_t dim = 0; dim < sparse_dim; ++dim) {
284             const auto dim_hash_coeff = hash_coeffs[dim];
285             const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
286             hash += dim_index * dim_hash_coeff;
287           }
288           return hash;
289       });
290     }
291 
292     return hash;
293   }();
294 
295   // Now that we have hash values of probably_coalesced.indices,
296   // we need to decide whether they need to get sorted.
297   // The sort is not requires if probably_coalesced is coalesced.
298   Tensor sorted_hash, argsort_hash;
299   std::tie(sorted_hash, argsort_hash) = [&]() -> std::tuple<Tensor, Tensor> {
300     if (probably_coalesced.is_coalesced()) {
301       // NOTE: argsort.dtype == nnz_arange.dtype
302       const auto argsort = nnz_arange.narrow(-1, 0, probably_coalesced._nnz());
303       return std::make_tuple(probably_coalesced_indices_hash, argsort);
304     } else {
305       // NOTE: we want argsort.dtype == nnz_arange.dtype,
306       // but sort() produces indices of type int64_t,
307       // so we convert to nnz_arange.dtype to avoid issues
308       // with pointer types in the kernels below.
309       Tensor sorted, argsort;
310       std::tie(sorted, argsort) = probably_coalesced_indices_hash.sort();
311       return std::make_tuple(sorted, argsort.to(nnz_arange.scalar_type()));
312     }
313   }();
314 
315   // Perform hash intersection.
316   // Let  s_hash = hash(source.indices),
317   //     pc_hash = hash(probably_coalesced.indices), then
318   // for i = 0, ..., len(s_hash) - 1:
319   //     lb = <index of a value in pc_hash[argsort_hash] which is a lower bound for s_hash[i]>,
320   //     up = <index of a value in pc_hash[argsort_hash] which is an upper bound for s_hash[i]>,
321   //     intersection_count[i] = up - lb
322   //     intersection_first_idx[i] = lb.
323   //
324   // intersection_count and intersection_first_idx are used to form indices at which
325   // intersection values are selected.
326   auto [intersection_count, intersection_first_idx] = [&]() -> std::tuple<Tensor, Tensor> {
327     const auto source_nnz = source._nnz();
328     auto intersection_buffer = at::empty({2, source_nnz}, sorted_hash.options());
329     auto intersection_count = intersection_buffer.select(0, 0);
330     auto intersection_first_idx = intersection_buffer.select(0, 1);
331 
332     const auto source_indices = source._indices();
333     const auto source_arange = nnz_arange.narrow(-1, 0, source_nnz);
334     // non-const because of gcc-5/clang-5 issues
335     auto indices_dim_stride = source_indices.stride(0);
336     auto indices_nnz_stride = source_indices.stride(1);
337     auto dummy = at::empty({1}, source_arange.options());
338 
339     auto hash = source_indices_hash_opt.has_value()
340       ? (*source_indices_hash_opt).contiguous()
341       : at::empty({0}, probably_coalesced._indices().options().dtype(kLong));
342     const auto* RESTRICT hash_ptr = source_indices_hash_opt.has_value()
343       ? hash.data_ptr<int64_t>()
344       : nullptr;
345 
346     auto iter = TensorIteratorConfig()
347       .set_check_mem_overlap(false)
348       .add_owned_output(dummy.expand_as(source_arange))
349       .add_input(source_arange)
350       .build();
351 
352     {
353       const auto* RESTRICT ptr_indices = source_indices.const_data_ptr<index_t>();
354       const auto* RESTRICT ptr_sorted_hash = sorted_hash.const_data_ptr<int64_t>();
355       const auto sorted_hash_len = sorted_hash.numel();
356       auto* RESTRICT ptr_intersection_count = intersection_count.data_ptr<int64_t>();
357       auto* RESTRICT ptr_intersection_first_idx = intersection_first_idx.data_ptr<int64_t>();
358 
359       // Fusing hash computation with hash intersection.
360       KernelLauncher::launch(iter,
361           // NOTE: capture by value required by CUDA
362           [=] FUNCAPI (index_t nnz_idx) -> index_t {
363           int64_t hash = 0;
364           if (hash_ptr) {
365             hash = hash_ptr[nnz_idx];
366           } else if (sparse_dim) {
367             // Compute hash value
368             const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
369             for (int64_t dim = 0; dim < sparse_dim; ++dim) {
370               const auto dim_hash_coeff = hash_coeffs[dim];
371               const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
372               hash += dim_index * dim_hash_coeff;
373             }
374           }
375 
376           // Perform hash values intersection
377           const auto* RESTRICT lb = find_bound<const int64_t*, int64_t, /*is_lower=*/true>(
378               ptr_sorted_hash,
379               ptr_sorted_hash + sorted_hash_len,
380               hash
381           );
382 
383           const auto* RESTRICT ub = find_bound<const int64_t*, int64_t, /*is_lower=*/false>(
384               ptr_sorted_hash,
385               ptr_sorted_hash + sorted_hash_len,
386               hash
387           );
388 
389           ptr_intersection_count[nnz_idx] = ub - lb;
390           ptr_intersection_first_idx[nnz_idx] = lb - ptr_sorted_hash;
391 
392           return 0;
393       });
394     }
395 
396     return std::make_tuple(intersection_count, intersection_first_idx);
397   }();
398 
399   const auto res_indices = source._indices().clone();
400   const auto binary_op_res_dtype = at::result_type(source._values(), probably_coalesced._values());
401   const auto res_values = value_selection_intersection_kernel_t::apply(
402       source._values().to(binary_op_res_dtype),
403       nnz_arange.narrow(-1, 0, source._nnz()),
404       probably_coalesced._values().to(binary_op_res_dtype),
405       intersection_first_idx.to(nnz_arange.scalar_type()),
406       intersection_count,
407       argsort_hash,
408       accumulate_matches).to(res.scalar_type());
409   const auto res_sparse_dim = source.sparse_dim();
410   const auto res_dense_dim = source.dense_dim();
411   const auto& res_shape = broadcasted_shape;
412   const auto res_nnz = source._nnz();
413 
414   auto* res_sparse_impl = get_sparse_impl(res);
415   res_sparse_impl->raw_resize_(res_sparse_dim, res_dense_dim, res_shape);
416   res_sparse_impl->set_indices_and_values_unsafe(res_indices, res_values);
417   res_sparse_impl->set_nnz_and_narrow(res_nnz);
418   res._coalesced_(source.is_coalesced());
419 }
420 
421 template <
422   template <typename func_t> class kernel_t,
423   typename value_selection_intersection_kernel_t>
424 void _sparse_binary_op_intersection_kernel_out(
425     Tensor& res,
426     const Tensor& x,
427     const Tensor& y,
428     const std::optional<Tensor>& x_hash_opt = std::nullopt,
429     const std::optional<Tensor>& y_hash_opt = std::nullopt,
430     // If op distributes with the sum, the arguments are processed as is,
431     // without the calls to coalesce().
432     const bool distributive_with_sum = true
433 ) {
434   TORCH_CHECK(
435       (x.is_sparse() && y.is_sparse())
436       && (x.dim() == y.dim()) && (x.sparse_dim() == y.sparse_dim())
437       && (x.sizes().slice(0, x.sparse_dim()) == y.sizes().slice(0, y.sparse_dim())),
438       NAME, "(): expects sparse inputs with equal dimensionality, ",
439       "number of sparse dimensions, and shape of sparse dimensions");
440   TORCH_CHECK(
441       x._indices().scalar_type() == y._indices().scalar_type(),
442       NAME, "(): expects inputs' indices to be of the same dtype (i.e. long or int)");
443 
444   const auto check_hash_validity = [](const Tensor& t, const std::optional<Tensor>& t_hash_opt) {
445     if (!t_hash_opt.has_value()) {
446       return;
447     }
448 
449     const auto &t_hash = *t_hash_opt;
450     TORCH_INTERNAL_ASSERT(
451         t_hash.dim() == 1 && t_hash.scalar_type() == kLong && t_hash.size(-1) == t._indices().size(-1),
452         NAME, "(): explicit hash values need to be a 1-dim Long tensor with the ",
453         "NSE matching that of the corresponding sparse tensor.");
454   };
455 
456   check_hash_validity(x, x_hash_opt);
457   check_hash_validity(y, y_hash_opt);
458 
459   const auto broadcasted_shape = infer_size(x.sizes(), y.sizes());
460 
461   // 8 sparse dims should be more than enough?
462   constexpr int64_t max_sparse_dims = 8;
463 
464   // COO indices are only 64-bit integers for now.
465   using index_t = int64_t;
466 
467   if (max_sparse_dims > x.sparse_dim()) {
468     _sparse_binary_op_intersection_kernel_impl<
469       // For some reason MSVC complaints about passing constexpr max_sparse_dims
470       // as a template parameter claiming as if it is not know at compile time.
471       kernel_t, value_selection_intersection_kernel_t, index_t, 8>(
472         res, x, y, broadcasted_shape, x_hash_opt, y_hash_opt, distributive_with_sum);
473   } else {
474     _sparse_binary_op_intersection_kernel_impl<
475       kernel_t, value_selection_intersection_kernel_t, index_t>(
476         res, x, y, broadcasted_shape, x_hash_opt, y_hash_opt, distributive_with_sum);
477   }
478 }
479 
480 } // anonymous namespace
481 
482 } // at::native
483