1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/SparseCsrTensorUtils.h>
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/grad_mode.h>
8 #include <ATen/mkl/Sparse.h>
9 #include <ATen/native/BinaryOps.h>
10 #include <ATen/native/CPUBlas.h>
11 #include <ATen/native/Resize.h>
12 #include <ATen/native/SparseTensorUtils.h>
13 #include <ATen/native/TensorConversions.h>
14 #include <ATen/native/mkl/SparseBlasImpl.h>
15 #include <ATen/native/sparse/SparseBlasImpl.h>
16 #include <ATen/native/sparse/SparseCsrTensorMath.h>
17 #include <c10/macros/Macros.h>
18 #include <c10/util/irange.h>
19 #include <ATen/AccumulateType.h>
20 
21 #ifndef AT_PER_OPERATOR_HEADERS
22 #include <ATen/Functions.h>
23 #include <ATen/NativeFunctions.h>
24 #include <ATen/Operators.h>
25 #else
26 #include <ATen/ops/_conj_physical_native.h>
27 #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
28 #include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
29 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
30 #include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
31 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
32 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
33 #include <ATen/ops/_sparse_csr_prod_native.h>
34 #include <ATen/ops/_sparse_csr_sum_native.h>
35 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
36 #include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
37 #include <ATen/ops/_sparse_mm_reduce_impl_native.h>
38 #include <ATen/ops/_unique.h>
39 #include <ATen/ops/abs.h>
40 #include <ATen/ops/abs_native.h>
41 #include <ATen/ops/add.h>
42 #include <ATen/ops/add_native.h>
43 #include <ATen/ops/addmm.h>
44 #include <ATen/ops/addmm_native.h>
45 #include <ATen/ops/angle.h>
46 #include <ATen/ops/angle_native.h>
47 #include <ATen/ops/asin.h>
48 #include <ATen/ops/asin_native.h>
49 #include <ATen/ops/asinh.h>
50 #include <ATen/ops/asinh_native.h>
51 #include <ATen/ops/atan.h>
52 #include <ATen/ops/atan_native.h>
53 #include <ATen/ops/atanh.h>
54 #include <ATen/ops/atanh_native.h>
55 #include <ATen/ops/ceil.h>
56 #include <ATen/ops/ceil_native.h>
57 #include <ATen/ops/conj_physical.h>
58 #include <ATen/ops/conj_physical_native.h>
59 #include <ATen/ops/copy_native.h>
60 #include <ATen/ops/deg2rad.h>
61 #include <ATen/ops/deg2rad_native.h>
62 #include <ATen/ops/empty.h>
63 #include <ATen/ops/empty_like.h>
64 #include <ATen/ops/erf.h>
65 #include <ATen/ops/erf_native.h>
66 #include <ATen/ops/erfinv.h>
67 #include <ATen/ops/erfinv_native.h>
68 #include <ATen/ops/expm1.h>
69 #include <ATen/ops/expm1_native.h>
70 #include <ATen/ops/fill_native.h>
71 #include <ATen/ops/floor.h>
72 #include <ATen/ops/floor_native.h>
73 #include <ATen/ops/frac.h>
74 #include <ATen/ops/frac_native.h>
75 #include <ATen/ops/isinf.h>
76 #include <ATen/ops/isinf_native.h>
77 #include <ATen/ops/isnan.h>
78 #include <ATen/ops/isnan_native.h>
79 #include <ATen/ops/isneginf.h>
80 #include <ATen/ops/isneginf_native.h>
81 #include <ATen/ops/isposinf.h>
82 #include <ATen/ops/isposinf_native.h>
83 #include <ATen/ops/log1p.h>
84 #include <ATen/ops/log1p_native.h>
85 #include <ATen/ops/mm_native.h>
86 #include <ATen/ops/mul.h>
87 #include <ATen/ops/mul_native.h>
88 #include <ATen/ops/neg.h>
89 #include <ATen/ops/neg_native.h>
90 #include <ATen/ops/normal_native.h>
91 #include <ATen/ops/ones.h>
92 #include <ATen/ops/ones_like.h>
93 #include <ATen/ops/rad2deg.h>
94 #include <ATen/ops/rad2deg_native.h>
95 #include <ATen/ops/relu.h>
96 #include <ATen/ops/relu_native.h>
97 #include <ATen/ops/resize_as_sparse_native.h>
98 #include <ATen/ops/result_type.h>
99 #include <ATen/ops/round.h>
100 #include <ATen/ops/round_native.h>
101 #include <ATen/ops/round_ops.h>
102 #include <ATen/ops/sgn.h>
103 #include <ATen/ops/sgn_native.h>
104 #include <ATen/ops/sign.h>
105 #include <ATen/ops/sign_native.h>
106 #include <ATen/ops/signbit.h>
107 #include <ATen/ops/signbit_native.h>
108 #include <ATen/ops/sin.h>
109 #include <ATen/ops/sin_native.h>
110 #include <ATen/ops/sinh.h>
111 #include <ATen/ops/sinh_native.h>
112 #include <ATen/ops/sparse_mask.h>
113 #include <ATen/ops/sparse_mask_native.h>
114 #include <ATen/ops/sqrt.h>
115 #include <ATen/ops/sqrt_native.h>
116 #include <ATen/ops/tan.h>
117 #include <ATen/ops/tan_native.h>
118 #include <ATen/ops/tanh.h>
119 #include <ATen/ops/tanh_native.h>
120 #include <ATen/ops/tensor.h>
121 #include <ATen/ops/threshold_backward.h>
122 #include <ATen/ops/threshold_backward_native.h>
123 #include <ATen/ops/trunc.h>
124 #include <ATen/ops/trunc_native.h>
125 #include <ATen/ops/zero_native.h>
126 #include <ATen/ops/zeros.h>
127 #include <ATen/ops/zeros_like.h>
128 #endif
129 
130 #include <algorithm>
131 
132 namespace at {
133 namespace meta {
134 
135 TORCH_META_FUNC(_convert_indices_from_coo_to_csr)
136 (const Tensor& self, const int64_t size, const bool out_int32) {
137   TORCH_CHECK(self.dim() <= 1, "Input is supposed to be a vector, but got ",
138               self.dim(), " dimensional tensor.");
139   ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
140   c10::TensorOptions options =
141       TensorOptions().device(self.options().device()).dtype(scalar_type);
142   set_output_raw_strided(0, size + 1, {}, options);
143 }
144 
145 TORCH_META_FUNC(_convert_indices_from_csr_to_coo)
146 (const Tensor& crow_indices,
147  const Tensor& col_indices,
148  const bool out_int32,
149  const bool transpose) {
150   TORCH_CHECK(
151     crow_indices.dim() == col_indices.dim(), "crow_indices and col_indices are supposed to have"
152     " the same dimensionality, but got ", crow_indices.dim(), " and ",
153     col_indices.dim(), " dimensional tensors, respectively.");
154   ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
155   c10::TensorOptions options = crow_indices.options().dtype(scalar_type);
156   set_output_raw_strided(0, {col_indices.dim() + 1, col_indices.numel()}, {}, options, {});
157 }
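// Illustrative sketch, assuming the generated at:: C++ API for these private
// index-conversion ops; the expected values in the trailing comments are a
// reading of the meta functions above, not verified output:
//
//   at::Tensor crow = at::tensor(at::ArrayRef<int64_t>{0, 1, 2, 3, 3, 5});  // nrows + 1 entries
//   at::Tensor col  = at::tensor(at::ArrayRef<int64_t>{0, 3, 2, 0, 2});     // nnz entries
//
//   // CSR -> COO: a (col.dim() + 1) x nnz tensor holding (row, col) indices.
//   at::Tensor coo = at::_convert_indices_from_csr_to_coo(crow, col);
//   // coo == [[0, 1, 2, 4, 4],
//   //         [0, 3, 2, 0, 2]]
//
//   // COO -> CSR: the second argument is the number of rows; the output has
//   // (number of rows) + 1 entries.
//   at::Tensor crow2 = at::_convert_indices_from_coo_to_csr(coo[0], 5);
//   // crow2 == crow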
158 
159 } // namespace meta
160 
161 namespace {
162 
163 template <typename F>
164 Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result) {
165   TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
166   TORCH_INTERNAL_ASSERT(result.is_sparse_csr());
167 
168   if (!result.is_same(self)) {
169     // For the case of (0x0) result tensor, manually resize `result` tensor
170     // to the size of `self` tensor
171     if (result.numel() == 0) {
172       at::native::resize_as_sparse_compressed_(result, self);
173     }
174     // copy_sparse_compressed_ internally checks the sizes of result and self tensors
175     // Hence no external size check required
176     at::native::copy_sparse_compressed_(result, self);
177   }
178 
179   auto self_values = self.values();
180   auto result_values = result.values();
181 
182   op_out(self_values, result_values);
183   return result;
184 }
185 
186 template <typename F, typename... Args>
187 Tensor& unary_op_inplace(Tensor& self, const F& op_inplace, Args&&... args) {
188   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "unary_op_inplace", [](){});
189 
190   auto self_values = self.values();
191   (self_values.*op_inplace)(std::forward<Args>(args)...);
192   return self;
193 }
194 
195 } // end anonymous namespace
196 
197 namespace native {
198 
199 using namespace at::sparse_csr;
200 // certain utility functions are usable from sparse COO.
201 using namespace at::sparse;
202 
203 Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
204   // TODO: Use a specialized CSR kernel for performance if needed
205   if (t_.is_sparse_csr() && src_.layout() == kStrided) {
206     return mul_out_sparse_csr(t_, src_.sparse_mask(t_), r);
207   }
208   if (t_.layout() == kStrided && src_.is_sparse_csr()) {
209     return mul_out_sparse_csr(t_.sparse_mask(src_), src_, r);
210   }
211   TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");
212   Tensor t = t_.to_sparse();
213   Tensor src = src_.to_sparse();
214   Tensor tmp_result = t.mul(src);
215   auto r_sparse_csr = tmp_result.to_sparse_csr();
216   r.resize_as_sparse_(r_sparse_csr);
217   r.copy_(r_sparse_csr);
218   return r;
219 }
220 
221 template <typename op_t>
222 Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, const Tensor& scalar, const op_t& op) {
223   // NOTE: intersection_binary_op_with_wrapped_scalar assumes scalar.numel() == 1.
224   const auto result_values = op(sparse.values(), scalar.squeeze()).to(at::result_type(sparse, scalar));
225   const auto result_sizes = infer_size(sparse.sizes(), scalar.sizes());
226   auto [compressed_indices, plain_indices] = getCompressedPlainIndices(sparse);
227   return at::_sparse_compressed_tensor_unsafe(
228       compressed_indices.clone(),
229       plain_indices.clone(),
230       result_values,
231       result_sizes,
232       sparse.options().dtype(result_values.scalar_type()));
233 }
234 
235 template <typename op_t>
236 Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const string& op_name, const op_t& op) {
237   // NOTE: intersection_binary_op_with_wrapped_scalar_ assumes scalar.numel() == 1.
238   const auto broadcasted_shape = infer_size(sparse.sizes(), scalar.sizes());
239   if (sparse.sizes() != broadcasted_shape) {
240     TORCH_CHECK(false, op_name, "(): output with shape ", sparse.sizes(), " does not match ",
241         "the broadcast shape ", broadcasted_shape);
242   }
243   auto values = sparse.values();
244   // Safe to use squeeze here, we already know that scalar safely broadcasts.
245   op(values, scalar.squeeze());
246   return sparse;
247 }
248 
249 Tensor mul_sparse_csr(const Tensor& self, const Tensor& other) {
250   // Check if either of the arguments is a wrapped Scalar
251   if (self.layout() == kStrided && self.dim() == 0) {
252     return intersection_binary_op_with_wrapped_scalar(other, self, [](const Tensor& a, const Tensor& b) -> Tensor {
253         return a.mul(b);
254     });
255   }
256   if (other.layout() == kStrided && other.dim() == 0) {
257     return intersection_binary_op_with_wrapped_scalar(self, other, [](const Tensor& a, const Tensor& b) -> Tensor {
258         return a.mul(b);
259     });
260   }
261 
262   if (self.is_sparse_csr() && other.layout() == kStrided) {
263     return mul_sparse_csr(self, other.sparse_mask(self));
264   }
265   if (self.layout() == kStrided && other.is_sparse_csr()) {
266     return mul_sparse_csr(self.sparse_mask(other), other);
267   }
268 
269   auto commonDtype = at::result_type(self, other);
270   auto result_options = self.options().dtype(commonDtype);
271   // CSR is 2d!
272   Tensor result = at::empty({0, 0}, result_options);
273   return at::mul_out(result, self, other); // redispatch!
274 }
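// Illustrative sketch of the intersection semantics implemented above: a
// strided operand is first masked by the sparse operand's sparsity pattern, so
// the result keeps the CSR layout, while a wrapped 0-dim tensor merely
// rescales the values. The expected values in the trailing comments are a
// reading of the code, not verified output:
//
//   at::Tensor a = at::eye(2, 3).to_sparse_csr();
//   at::Tensor d = at::full({2, 3}, 10.0);
//
//   at::Tensor r = a.mul(d);                       // CSR result, values {10., 10.}
//   at::Tensor s = a.mul(at::scalar_tensor(2.0));  // same indices, values {2., 2.}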
275 
276 Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) {
277   if (other.layout() == kStrided && other.dim() == 0) {
278     return intersection_binary_op_with_wrapped_scalar_(self, other, "mul_", [](Tensor& a, const Tensor& b) -> Tensor& {
279         return a.mul_(b);
280     });
281   }
282   return at::mul_out(self, self, other); // redispatch!
283 }
284 
285 
286 namespace {
287 
288 template <typename F>
289 inline Tensor get_result_tensor_for_unary_op(F op, const Tensor& input) {
290   auto values = input.values();
291 
292   // To handle type promotion for inputs to unary ops,
293   // we first get the result from the underlying op, and use the result
294   // to create a sparse compressed tensor, which is used as the input to the out=
295   // variant
296   auto result_values = op(values);
297 
298   auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(),
299                                                                       "get_result_tensor_for_unary_op",
300                                                                       [&]{ return input.crow_indices(); },
301                                                                       [&]{ return input.ccol_indices(); });
302   auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(),
303                                                                  "get_result_tensor_for_unary_op",
304                                                                  [&]{ return input.col_indices(); },
305                                                                  [&]{ return input.row_indices(); });
306 
307   auto result = at::_sparse_compressed_tensor_unsafe(
308       compressed_indices.clone(),
309       plain_indices.clone(),
310       result_values,
311       input.sizes(),
312       input.options().dtype(result_values.scalar_type()));
313 
314   return result;
315 }
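// Illustrative sketch: because the result dtype above is taken from the dense
// op applied to `values`, a unary ufunc on an integer CSR tensor can produce a
// floating-point CSR tensor (expected values are a reading of the code, not
// verified output):
//
//   at::Tensor t = (at::eye(2) * 4).to_sparse_csr().to(at::kLong);
//   at::sqrt(t).values();  // floating-point values {2., 2.} although t is kLong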
316 } // namespace
317 
318 Tensor& normal_sparse_csr_(
319     Tensor& self,
320     double mean,
321     double std,
322     std::optional<Generator> gen) {
323   return unary_op_inplace(self, &Tensor::normal_, mean, std, gen);
324 }
325 
326 Tensor& fill_sparse_csr_(Tensor& self, const Scalar& value) {
327   return unary_op_inplace(self, &TensorBase::fill_, value);
328 }
329 
330 Tensor sparse_mask_sparse_compressed(
331     const Tensor& self,
332     const Tensor& mask) {
333   TORCH_CHECK(at::sparse_csr::is_sparse_compressed(mask),
334               "sparse_mask_sparse_compressed expects mask to have sparse compressed layout, got ", mask.layout());
335   TORCH_CHECK(
336       mask.sizes().equals(self.sizes()),
337       "sparse_mask(): operands have incompatible sizes; self has size ",
338       self.sizes(),
339       " but mask has size ",
340       mask.sizes());
341 
342   if (self.is_same(mask)) {
343     return self;
344   }
345 
346   if (!mask.numel() || !mask._nnz()) {
347     return mask.clone().to(self.device(), self.scalar_type());
348   }
349 
350   if (self.layout() == kStrided) {
351     auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(mask);
352     auto mask_values = mask.values();
353     auto dense_mask = at::_sparse_compressed_tensor_unsafe(
354         compressed_indices,
355         plain_indices,
356         at::ones({1}, self.options().dtype(kBool)).expand_as(mask_values),
357         self.sizes(),
358         self.options().dtype(kBool).layout(mask.layout())).to_dense();
359     return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
360         mask.layout(), "sparse_mask_sparse_compressed",
361         [&] {
362           return at::native::dense_to_sparse_with_mask(self, dense_mask, mask.layout(), {}, mask.dense_dim());
363         },
364         [&] {
365           auto blocksize = at::sparse_csr::getBlockSize(mask);
366           return at::native::dense_to_sparse_with_mask(self, dense_mask, mask.layout(), blocksize, mask.dense_dim());
367         });
368   } else if (self.layout() == mask.layout()) {
369     // TODO: keeping this for BC but the method used here may lead to
370     // incorrect indices.
371     return self.mul(at::ones_like(mask)).to(self.scalar_type());
372   } else {
373     // TODO: keeping this for BC but the method used here cannot
374     // support batch dimensions because sparse COO tensors are batch
375     // dimension ignorant.
376     return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
377         mask.layout(), "sparse_mask_sparse_compressed",
378         [&] {
379           return self.sparse_mask(mask.to_sparse()).to_sparse(mask.layout());
380         },
381         [&] {
382           auto blocksize = at::sparse_csr::getBlockSize(mask);
383           return self.sparse_mask(mask.to_sparse()).to_sparse(mask.layout(), blocksize);
384         });
385   }
386 }
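// Illustrative sketch: `sparse_mask` with a sparse compressed mask keeps the
// entries of a strided `self` at the mask's specified positions and returns a
// tensor with the mask's layout (expected values are a reading of the code,
// not verified output):
//
//   at::Tensor dense = at::arange(6, at::kFloat).reshape({2, 3});
//   at::Tensor mask = at::eye(2, 3).to_sparse_csr();
//
//   at::Tensor masked = dense.sparse_mask(mask);  // CSR, values {0., 4.}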
387 
388 Tensor mul_scalar_sparse_csr(const Tensor& self, const Scalar& other) {
389   auto result_values = self.values().mul(other);
390   return at::native::_sparse_csr_tensor_unsafe(
391       self.crow_indices().clone(),
392       self.col_indices().clone(),
393       result_values,
394       self.sizes(),
395       result_values.scalar_type(),
396       self.layout(),
397       result_values.device());
398 }
399 
400 Tensor& zero_sparse_csr_(Tensor& self) {
401   /*
402     csr.zero_() resets nnz to 0.
403 
404     If the original sparsity pattern needs to be preserved, use
405     `csr.values().zero_()` instead.
406 
407     The above behavior also implies that torch.zeros_like(csr) returns
408     a new tensor with nnz == 0. If one needs a zeros_like semantics
409     where the result has the same sparsity pattern as input, then use
410     `result = csr.clone(); result.values().zero_();`
411   */
412   AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
413   get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
414   return self;
415 }
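// Illustrative sketch of the two zeroing behaviours described above (expected
// nnz values are a reading of the code, not verified output):
//
//   at::Tensor csr = at::eye(3).to_sparse_csr();
//
//   at::Tensor a = csr.clone();
//   a.values().zero_();   // keeps the sparsity pattern, a._nnz() == 3
//
//   at::Tensor b = csr.clone();
//   b.zero_();            // drops the pattern, b._nnz() == 0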
416 
417 /* Implementation of unary ufuncs supported for the sparse CSR layout.
418  * Only simple funcs with a 0 -> 0 correspondence (i.e. f(0) == 0) are currently supported. */
419 
420 #define CREATE_UNARY_UFUNC_OUT(op_name)                                  \
421   Tensor& op_name##_sparse_csr_out(const Tensor& self, Tensor& result) { \
422     return unary_op_out(&at::op_name##_outf, self, result);              \
423   }
424 
425 #define CREATE_UNARY_UFUNC_FUNCTIONAL(op_name)                 \
426   Tensor op_name##_sparse_csr(const Tensor& self) {            \
427     return get_result_tensor_for_unary_op(&at::op_name, self); \
428   }
429 
430 #define CREATE_UNARY_UFUNC_INPLACE(op_name)             \
431   Tensor& op_name##_sparse_csr_(Tensor& self) {         \
432     return unary_op_inplace(self, &Tensor::op_name##_); \
433   }
434 
435 #define CREATE_UNARY_UFUNC(op_name)       \
436   CREATE_UNARY_UFUNC_OUT(op_name);        \
437   CREATE_UNARY_UFUNC_FUNCTIONAL(op_name); \
438   CREATE_UNARY_UFUNC_INPLACE(op_name);
439 
440 #define CREATE_UNARY_UFUNC_NO_INPLACE(op_name) \
441   CREATE_UNARY_UFUNC_OUT(op_name);             \
442   CREATE_UNARY_UFUNC_FUNCTIONAL(op_name);
443 
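// For readability: a CREATE_UNARY_UFUNC(op) invocation below expands to the
// out=, functional, and in-place kernels for that op, e.g. for sqrt,
// approximately:
//
//   Tensor& sqrt_sparse_csr_out(const Tensor& self, Tensor& result) {
//     return unary_op_out(&at::sqrt_outf, self, result);
//   }
//   Tensor sqrt_sparse_csr(const Tensor& self) {
//     return get_result_tensor_for_unary_op(&at::sqrt, self);
//   }
//   Tensor& sqrt_sparse_csr_(Tensor& self) {
//     return unary_op_inplace(self, &Tensor::sqrt_);
//   }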
444 // Exhaustive list of the unary ufuncs supported by sparse compressed
445 CREATE_UNARY_UFUNC(abs);
446 CREATE_UNARY_UFUNC(asin);
447 CREATE_UNARY_UFUNC(asinh);
448 CREATE_UNARY_UFUNC(atan);
449 CREATE_UNARY_UFUNC(atanh);
450 CREATE_UNARY_UFUNC(ceil);
451 CREATE_UNARY_UFUNC(deg2rad);
452 CREATE_UNARY_UFUNC(erf);
453 CREATE_UNARY_UFUNC(erfinv);
454 CREATE_UNARY_UFUNC(expm1);
455 CREATE_UNARY_UFUNC(floor);
456 CREATE_UNARY_UFUNC(frac);
457 CREATE_UNARY_UFUNC(log1p);
458 CREATE_UNARY_UFUNC(neg);
459 CREATE_UNARY_UFUNC(rad2deg);
460 CREATE_UNARY_UFUNC(sign);
461 CREATE_UNARY_UFUNC(sin);
462 CREATE_UNARY_UFUNC(sinh);
463 CREATE_UNARY_UFUNC(sgn);
464 CREATE_UNARY_UFUNC(sqrt);
465 CREATE_UNARY_UFUNC(tan);
466 CREATE_UNARY_UFUNC(tanh);
467 CREATE_UNARY_UFUNC(trunc);
468 CREATE_UNARY_UFUNC(conj_physical);
469 
470 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
471 static CREATE_UNARY_UFUNC(relu);
472 C10_DIAGNOSTIC_POP()
473 
474 // With the addition of the `round.decimals` overload, using CREATE_UNARY_UFUNC leads
475 // to an unresolved overload.
476 Tensor& round_sparse_csr_out(const Tensor& self, Tensor& result) {
477   return unary_op_out(&at::_ops::round_out::call, self, result);
478 }
479 
480 Tensor round_sparse_csr(const Tensor& self) {
481   return get_result_tensor_for_unary_op(&at::_ops::round::call, self);
482 }
483 
484 Tensor& round_sparse_csr_(Tensor& self) {
485   TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
486   self.values().round_();
487   return self;
488 }
489 
490 Tensor threshold_backward_sparse_compressed(
491     const Tensor& grad_output,
492     const Tensor& self,
493     const Scalar& threshold) {
494   return get_result_tensor_for_unary_op(
495       [&](const Tensor& t) {
496         return at::threshold_backward(t, self.values(), threshold);
497       },
498       grad_output);
499 }
500 
501 Tensor& threshold_backward_sparse_compressed_out(
502     const Tensor& grad_output,
503     const Tensor& self,
504     const Scalar& threshold,
505     Tensor& grad_input) {
506   return unary_op_out(
507       [&](const Tensor& t, Tensor& out) {
508         return at::threshold_backward_outf(t, self.values(), threshold, out);
509       },
510       grad_output,
511       grad_input);
512 }
513 
514 // angle, isneginf, isposinf and signbit currently don't have an inplace variant
515 CREATE_UNARY_UFUNC_NO_INPLACE(angle);
516 CREATE_UNARY_UFUNC_NO_INPLACE(isneginf);
517 CREATE_UNARY_UFUNC_NO_INPLACE(isposinf);
518 CREATE_UNARY_UFUNC_NO_INPLACE(signbit);
519 
520 // isnan and isinf don't have an out variant
521 CREATE_UNARY_UFUNC_FUNCTIONAL(isnan);
522 CREATE_UNARY_UFUNC_FUNCTIONAL(isinf);
523 
524 template <typename scalar_t>
525 void addmm_out_sparse_csr_native_cpu(
526     const Tensor& sparse,
527     const Tensor& dense,
528     const Tensor& r,
529     Scalar alpha,
530     Scalar beta) {
531   auto dim_i = sparse.size(0);
532   auto dim_k = dense.size(1);
533 
534   auto csr = sparse.crow_indices();
535   auto col_indices = sparse.col_indices();
536   auto values = sparse.values();
537 
538   scalar_t cast_alpha = alpha.to<scalar_t>();
539   r.mul_(beta);
540   AT_DISPATCH_INDEX_TYPES(
541       col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
542         auto csr_accessor = csr.accessor<index_t, 1>();
543         auto col_indices_accessor = col_indices.accessor<index_t, 1>();
544 
545         auto values_accessor = values.accessor<scalar_t, 1>();
546         scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
547         scalar_t* r_ptr = r.data_ptr<scalar_t>();
548 
549         int64_t dense_stride0 = dense.stride(0);
550         int64_t dense_stride1 = dense.stride(1);
551         int64_t r_stride0 = r.stride(0);
552         int64_t r_stride1 = r.stride(1);
553 
554         at::parallel_for(
555             0,
556             dim_i,
557             internal::GRAIN_SIZE,
558             [&](int64_t irow_start, int64_t irow_end) {
559               for (index_t h = irow_start; h < irow_end; ++h) {
560                 index_t i_start = csr_accessor[h];
561                 index_t i_end = csr_accessor[h + 1];
562                 for (index_t i = i_start; i < i_end; i++) {
563                   scalar_t val = values_accessor[i];
564                   index_t col = col_indices_accessor[i];
565                   at::native::cpublas::axpy<scalar_t>(
566                       dim_k,
567                       cast_alpha * val,
568                       dense_ptr + col * dense_stride0,
569                       dense_stride1,
570                       r_ptr + h * r_stride0,
571                       r_stride1);
572                 }
573               }
574             });
575       });
576 }
577 
578 // Functions for matrix multiplication.
579 // result = beta * self + alpha * (mat1 @ mat2)
580 Tensor& addmm_out_sparse_compressed_cpu(
581     const Tensor& self,
582     const Tensor& mat1,
583     const Tensor& mat2,
584     const Scalar& beta,
585     const Scalar& alpha,
586     Tensor& result) {
587   // All the checks are from addmm_out_cuda_impl (ATen/native/cuda/Blas.cpp) and
588   // TORCH_META_FUNC(addmm) (ATen/native/LinearAlgebra.cpp)
589   // TODO: remove code duplication and unify code
590   sparse::impl::_check_dim(mat1, 2, "mat1");
591   sparse::impl::_check_dim(mat2, 2, "mat2");
592 
593   TORCH_CHECK(
594       mat1.size(1) == mat2.size(0), "mat1 and mat2 shapes cannot be multiplied (",
595       mat1.size(0), "x", mat1.size(1), " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
596 
597   c10::MaybeOwned<at::Tensor> self_;
598   // Don't expand self if this is an in-place operation
599   if (&result == &self) {
600      self_ = c10::MaybeOwned<Tensor>::borrowed(self);
601   } else {
602      self_ = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");
603   }
604 
605 
606   TORCH_CHECK(((self_->dim() == 2) &&
607                (self_->size(0) == mat1.size(0)) &&
608                (self_->size(1) == mat2.size(1))),
609               "The input tensor must be a matrix with size ",
610               mat1.size(0),
611               "x",
612               mat2.size(1),
613               ", but got a ",
614               self_->dim(),
615               "-D tensor with size ",
616               self_->size(0),
617               "x",
618               self_->size(1));
619 
620   if (&result != &self) {
621     if (result.layout() == kStrided) {
622       at::native::resize_output(result, self_->sizes());
623     } else {
624       result.resize_as_sparse_(*self_);
625     }
626     result.copy_(*self_);
627   }
628 
629   if (result.numel() == 0) {
630     // If result gets resized and is sparse compressed,
631     // its compressed_indices tensor will contain junk values,
632     // so the whole tensor is not a valid compressed tensor.
633     // To combat that, result needs to get zeroed out.
634     if (at::sparse_csr::is_sparse_compressed(result)) {
635       result.zero_();
636     }
637     return result;
638   }
639 
640   if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) {
641     // According to the docs, when beta == 0 the values in self should be
642     // ignored; NaNs and Infs should not propagate.
643     if (beta.toComplexDouble() == 0.) {
644       result.zero_();
645     } else {
646       result.mul_(beta);
647     }
648     return result;
649   }
650 
651 #if !AT_USE_MKL_SPARSE()
652   // The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @
653   // strided -> strided
654   if (mat1.layout() == kStrided) {
655     if (mat2.layout() == kSparseCsr) {
656       if (result.layout() == kStrided) {
657         AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
658             result.scalar_type(), "addmm_sparse_dense", [&] {
659               addmm_out_sparse_csr_native_cpu<scalar_t>(
660                   mat2.transpose(-2, -1).to_sparse_csr(),
661                   mat1.transpose(-2, -1),
662                   result.transpose(-2, -1),
663                   alpha,
664                   beta);
665             });
666         return result;
667       }
668     }
669     if (mat2.layout() == kSparseCsc) {
670       if (result.layout() == kStrided) {
671         AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
672             result.scalar_type(), "addmm_sparse_dense", [&] {
673               addmm_out_sparse_csr_native_cpu<scalar_t>(
674                   mat2.transpose(-2, -1),
675                   mat1.transpose(-2, -1),
676                   result.transpose(-2, -1),
677                   alpha,
678                   beta);
679             });
680         return result;
681       }
682     }
683   } else if (mat1.layout() == kSparseCsr) {
684     if (mat2.layout() == kStrided) {
685       if (result.layout() == kStrided) {
686         AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
687             result.scalar_type(), "addmm_sparse_dense", [&] {
688               addmm_out_sparse_csr_native_cpu<scalar_t>(
689                   mat1, mat2, result, alpha, beta);
690             });
691         return result;
692       }
693     }
694   } else if (mat1.layout() == kSparseCsc) {
695     if (mat2.layout() == kStrided) {
696       if (result.layout() == kStrided) {
697         AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
698             result.scalar_type(), "addmm_sparse_dense", [&] {
699               addmm_out_sparse_csr_native_cpu<scalar_t>(
700                   mat1.to_sparse_csr(), mat2, result, alpha, beta);
701             });
702         return result;
703       }
704     }
705   }
706   TORCH_CHECK(
707       false,
708       "addmm: computation on CPU is not implemented for ",
709       result.layout(),
710       " + ",
711       mat1.layout(),
712       " @ ",
713       mat2.layout(),
714       " without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors.");
715 #else
716   sparse::impl::mkl::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
717 #endif
718   return result;
719 }
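// Illustrative sketch of the `result = beta * self + alpha * (mat1 @ mat2)`
// contract with a sparse CSR mat1 and strided mat2/self; the shapes follow the
// checks above:
//
//   at::Tensor mat1 = at::eye(3).to_sparse_csr();
//   at::Tensor mat2 = at::ones({3, 4});
//   at::Tensor self = at::zeros({3, 4});
//
//   // Strided 3 x 4 result: beta * self + alpha * (identity @ ones).
//   at::Tensor out = at::addmm(self, mat1, mat2, 0.5, 2.0);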
720 
721 Tensor addmm_sparse_compressed_dense(
722     const Tensor& self,
723     const SparseCsrTensor& sparse,
724     const Tensor& dense,
725     const Scalar& beta,
726     const Scalar& alpha) {
727   Tensor r = at::empty({0, 0}, self.options());
728   at::addmm_out(r, self, sparse, dense, beta, alpha);
729   return r;
730 }
731 
732 Tensor& _sparse_csr_mm_out(
733     const Tensor& mat1,
734     const Tensor& mat2,
735     Tensor& result) {
736   auto zero = at::zeros_like(result);
737   return at::addmm_out(result, zero, mat1, mat2, 0.0, 1.0);
738 }
739 
740 Tensor _sparse_csr_mm(const Tensor& mat1, const Tensor& mat2) {
741   if (mat1.is_sparse_csr() && mat2.is_sparse_csr()) {
742     // Return sparse
743     return at::addmm(
744         at::zeros({mat1.size(0), mat2.size(1)}, mat2.options()),
745         mat1,
746         mat2,
747         0.0,
748         1.0);
749   }
750   if ((mat1.layout() == kSparseCsc || mat1.layout() == kSparseCsr) &&
751       (mat2.layout() == kSparseCsc || mat2.layout() == kSparseCsr)) {
752     // TODO: Expensive conversion to CSR. Should add native support for CSC.
753     // Covers CSC @ CSR
754     // Covers CSR @ CSC
755     // Covers CSC @ CSC
756     return _sparse_csr_mm(mat1.to_sparse_csr(), mat2.to_sparse_csr());
757   }
758   if (mat1.layout() == kSparseCsc && mat2.layout() == c10::kStrided) {
759     // TODO: This is a costly conversion. We should have
760     // native support for CSC.
761     return _sparse_csr_mm(mat1.to_sparse_csr(), mat2);
762   }
763   // Default to taking options from mat1
764   auto result_options = mat1.options();
765   if (mat2.layout() == kStrided) {
766     // If either arg is strided we return strided, so update the options if
767     // mat2 is strided.
768     result_options = result_options.layout(kStrided);
769   }
770   return at::addmm(
771       at::zeros({mat1.size(0), mat2.size(1)}, result_options),
772       mat1,
773       mat2,
774       0.0,
775       1.0);
776 }
777 
778 // Functions for element-wise addition.
779 Tensor add_sparse_csr(
780     const Tensor& self,
781     const Tensor& other,
782     const Scalar& alpha) {
783   auto commonDtype = at::result_type(self, other);
784   alpha_check(commonDtype, alpha);
785   Tensor result;
786   if (self.layout() != kStrided && other.layout() == kStrided) {
787     // add(sparse, dense) -> dense
788     result = at::empty_like(
789         other,
790         other.options()
791             .dtype(commonDtype)
792             .memory_format(at::MemoryFormat::Contiguous));
793   } else {
794     // add(dense, sparse) -> dense AND add(sparse, sparse) -> sparse
795     result = at::empty_like(
796         self,
797         self.options()
798             .dtype(commonDtype)
799             .memory_format(at::MemoryFormat::Contiguous));
800   }
801   return at::add_out(result, self, other, alpha); // redispatch!
802 }
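// Illustrative sketch of the layout promotion implemented above:
// add(sparse, dense) and add(dense, sparse) produce strided results, while
// add(sparse, sparse) stays sparse:
//
//   at::Tensor s = at::eye(3).to_sparse_csr();
//   at::Tensor d = at::ones({3, 3});
//
//   at::add(s, d).layout();  // kStrided
//   at::add(d, s).layout();  // kStrided
//   at::add(s, s).layout();  // kSparseCsr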
803 
804 Tensor& add_sparse_csr_(
805     Tensor& self,
806     const Tensor& other,
807     const Scalar& alpha) {
808   return at::add_out(self, self, other, alpha); // redispatch!
809 }
810 
811 static void add_out_dense_sparse_compressed_cpu(
812     const Tensor& out,
813     const Tensor& dense,
814     const SparseCsrTensor& src,
815     const Scalar& alpha) {
816   TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
817   TORCH_INTERNAL_ASSERT(
818       src.layout() == kSparseCsr || src.layout() == kSparseCsc);
819   TORCH_INTERNAL_ASSERT(dense.device() == kCPU || dense.device() == kMeta);
820 
821   TORCH_CHECK(
822       out.is_contiguous(),
823       "out argument must be contiguous, but got: ",
824       out.suggest_memory_format());
825   TORCH_CHECK(
826       out.device() == dense.device(),
827       "add: expected 'out' to match dense tensor, but got tensor on device: ",
828       out.device());
829   TORCH_CHECK(
830       src.device() == dense.device(),
831       "add: expected 'src' to match dense tensor, but got tensor on device: ",
832       src.device());
833 
834   TORCH_CHECK(
835       dense.sizes().equals(src.sizes()),
836       "add: expected 'self' and 'other' to have same size, but self has size ",
837       dense.sizes(),
838       " while other has size ",
839       src.sizes(),
840       " (FYI: op2-sparse addition does not currently support broadcasting)");
841 
842   auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
843   TORCH_CHECK(
844       canCast(commonDtype, out.scalar_type()),
845       "Can't convert result type ",
846       commonDtype,
847       " to output ",
848       out.scalar_type(),
849       " in add operation");
850 
851   auto src_values = src.values();
852 
853   resize_output(out, dense.sizes());
854 
855   Tensor resultBuffer = out;
856 
857   if (out.scalar_type() != commonDtype) {
858     resultBuffer = dense.to(commonDtype);
859   } else if (!is_same_tensor(out, dense)) {
860     resultBuffer.copy_(dense);
861   }
862 
863   if (src._nnz() == 0) {
864     return;
865   }
866 
867   TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
868 
869   auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)});
870   resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
871   Tensor src_compressed_indices;
872   Tensor src_plain_indices;
873   std::tie(src_compressed_indices, src_plain_indices) =
874       at::sparse_csr::getCompressedPlainIndices(src);
875   src_compressed_indices =
876       src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
877   src_plain_indices =
878       src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
879   auto src_layout = src.layout();
880 
881   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
882       kComplexHalf,
883       kHalf,
884       kBool,
885       kBFloat16,
886       commonDtype,
887       "add_out_op2_sparse_csr",
888       [&valuesBuffer,
889        &resultBuffer,
890        &alpha,
891        &src_compressed_indices,
892        &src_plain_indices,
893        &src_layout]() {
894         AT_DISPATCH_INDEX_TYPES(
895             src_compressed_indices.scalar_type(),
896             "csr_add_out_crow_indices",
897             [&valuesBuffer,
898              &resultBuffer,
899              &alpha,
900              &src_compressed_indices,
901              &src_plain_indices,
902              &src_layout]() {
903               auto batch_count =
904                   resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
905               auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
906               scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
907               scalar_t cast_value = alpha.to<scalar_t>();
908 
909               auto compressed_indices_accessor =
910                   src_compressed_indices.accessor<index_t, 2>();
911               auto plain_indices_accessor =
912                   src_plain_indices.accessor<index_t, 2>();
913               auto out_strides = resultBuffer.strides();
914               auto const out_stride_batch = out_strides[0];
915               auto const out_stride_compressed =
916                   AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
917                       src_layout,
918                       "add_out_dense_sparse_compressed_cpu",
919                       [&out_strides] { return out_strides[1]; },
920                       [&out_strides] { return out_strides[2]; });
921               auto const out_stride_plain =
922                   AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
923                       src_layout,
924                       "add_out_dense_sparse_compressed_cpu",
925                       [&out_strides] { return out_strides[2]; },
926                       [&out_strides] { return out_strides[1]; });
927 
928               for (const auto batch_idx : c10::irange(batch_count)) {
929                 for (const auto i_compressed :
930                      c10::irange(src_compressed_indices.size(-1) - 1)) {
931                   index_t start_index =
932                       compressed_indices_accessor[batch_idx][i_compressed];
933                   index_t end_index =
934                       compressed_indices_accessor[batch_idx][i_compressed + 1];
935                   for (const auto i : c10::irange(start_index, end_index)) {
936                     auto i_plain = plain_indices_accessor[batch_idx][i];
937                     auto index = batch_idx * out_stride_batch +
938                         i_compressed * out_stride_compressed +
939                         i_plain * out_stride_plain;
940                     out_ptr[index] +=
941                         cast_value * values_accessor[batch_idx][i];
942                   }
943                 }
944               }
945             });
946       });
947   if (out.scalar_type() != commonDtype) {
948     out.copy_(resultBuffer);
949   }
950 }
951 
952 Tensor& add_out_sparse_compressed_cpu(
953     const Tensor& self,
954     const SparseCsrTensor& other,
955     const Scalar& alpha,
956     SparseCsrTensor& out) {
957   if (self.layout() == kStrided) {
958     add_out_dense_sparse_compressed_cpu(out, self, other, alpha);
959   } else if (other.layout() == kStrided) {
960     add_out_dense_sparse_compressed_cpu(out, other, self, alpha);
961   } else {
962     TORCH_CHECK(
963         self.sizes().equals(other.sizes()),
964         "torch.add: Expected input tensors to have the same shape, but got tensor `self` with shape ",
965         self.sizes(),
966         " and tensor `other` with shape ",
967         other.sizes());
968 
969     if (only_sparse_compressed_add_trivial_cases(self, other, alpha, out)) {
970       return out;
971     }
972 
973     at::native::resize_as_sparse_compressed_(out, self);
974     sparse::impl::cpu::add_out_sparse_csr(self, other, alpha, out);
975   }
976   return out;
977 }
978 
979 /*
980     Reductions on sparse CSR tensors using masked semantics.
981 
982     - A CSR tensor is a 2D tensor that is specified by a 3-tuple
983       (crow_indices, col_indices, values).
984 
985     - To support a reduction operator on a CSR tensor, define:
986 
987 template <typename scalar_t>
988 struct Reduction...Op {
989   inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
990     return a ... b;
991   }
992   inline scalar_t identity() const { return ...; }
993 };
994 
995 Tensor _sparse_csr_..._cpu(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
996   ...
997       result = reduce_sparse_csr_cpu_template<scalar_t>(input_, dims_to_sum, keepdim, Reduction...Op<scalar_t>());
998   ...
999   return result;
1000 }
1001 
1002       and add the following
1003 
1004         - func: _sparse_csr_op.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
1005           dispatch:
1006             SparseCsrCPU: _sparse_csr_..._cpu
1007 
1008       to native_functions.yaml
1009 
1010       Use ReductionAddOp and _sparse_csr_sum implementation as an example.
1011 
1012     - Since the dimensionality of a CSR tensor is always 2, only reductions
1013       with keepdim=True can be supported.
1014 
1015 */
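// Illustrative sketch of the masked-semantics reductions implemented below,
// via the private ops that the dispatch entries above refer to; the expected
// shapes in the trailing comments are a reading of the code, not verified
// output:
//
//   at::Tensor t = at::eye(4).to_sparse_csr();
//
//   // Both reductions keep the CSR layout and require keepdim=true.
//   at::Tensor row_sums = at::_sparse_csr_sum(t, {1}, true);     // 4 x 1
//   at::Tensor total    = at::_sparse_csr_sum(t, {0, 1}, true);  // 1 x 1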
1016 
1017 namespace {
1018 
1019 template <typename scalar_t, typename ReductionOp>
1020 Tensor reduce_sparse_csr_dim0_cpu_template(const Tensor& sparse, ReductionOp rop) {
1021   /*
1022     Consider the following sparse tensor:
1023 
1024     1 * * * *
1025     * * * 2 *
1026     * * 3 * *
1027     * * * * *
1028     4 * 5 * *
1029 
1030     that has CSR representation
1031 
1032       crow_indices = [0, 1, 2, 3, 3, 5]
1033       col_indices = [0, 3, 2, 0, 2]
1034       values = [1, 2, 3, 4, 5]
1035 
1036     Reduction with dim=0 results:
1037 
1038     rop(1,4) * rop(3,5) 2 *
1039 
1040     that has CSR representation
1041 
1042       new_crow_indices = [0, 3]
1043       new_col_indices = [0, 2, 3]
1044       new_values = [rop(1, 4), rop(3, 5), 2]
1045 
1046     In general, the CSR representation data can be computed as follows:
1047 
1048       new_col_indices, col_map = col_indices.unique(sorted=True, return_inverse=True)
1049       nnz = new_col_indices.numel()
1050       new_crow_indices = [0, nnz]
1051       new_values.resize(nnz); new_values.fill_(identity)
1052       for i in range(col_indices.numel()):
1053           new_values[col_map[i]] = rop(new_values[col_map[i]], values[i])
1054    */
1055 
1056   Tensor col_indices = sparse.col_indices();
1057   Tensor values = sparse.values();
1058   auto numel = values.numel();
1059 
1060   /*
1061     Calling at::_unique constitutes the main bottleneck of this
1062     function. However, it is still about 5x faster than using the
1063     invariant:
1064       csr.sum(dim=0) == csr.transpose(0, 1).sum(dim=1)
1065   */
1066   auto [new_col_indices, columns_map] = at::_unique(col_indices, true, true);
1067   auto nnz = new_col_indices.numel();
1068 
1069   Tensor new_crow_indices = at::empty({2}, col_indices.options());
1070   new_crow_indices[0] = 0;
1071   new_crow_indices[1] = nnz;
1072 
1073   // Pass `is_cuda = true` to acc_type even on the CPU backend because the
1074   // accumulate type of float should be float here: acc_type<float, true> is
1075   // float, whereas acc_type<float, false> (the CPU default) is double.
1076   using acc_t = at::acc_type<scalar_t, true>;
1077   auto acc_buffer = at::sparse_csr::create_acc_buffer<acc_t, scalar_t>(
1078       values.options(), values.scalar_type(), nnz);
1079   Tensor new_values = std::get<0>(acc_buffer);
1080   Tensor new_values_acc = std::get<1>(acc_buffer);
1081   new_values_acc.fill_(rop.identity());
1082 
1083   int64_t* columns_map_ptr = columns_map.data_ptr<int64_t>();
1084   scalar_t* values_ptr = values.data_ptr<scalar_t>();
1085   acc_t* new_values_acc_ptr =
1086       new_values_acc.data_ptr<acc_t>();
1087 
1088   // There is no point in parallelizing the following for-loop
1089   // because about 99.3% of the computation time is spent in the
1090   // at::_unique call above.
1091   for (const auto i : c10::irange(numel)) {
1092     int64_t col = columns_map_ptr[i];
1093     scalar_t val = values_ptr[i];
1094     new_values_acc_ptr[col] = rop(new_values_acc_ptr[col], static_cast<acc_t>(val));
1095   }
1096   copy_from_acc_buffer(new_values, new_values_acc);
1097 
1098   return at::native::_sparse_csr_tensor_unsafe(new_crow_indices, new_col_indices, new_values,
1099                                               {1, sparse.size(1)},
1100                                               new_values.scalar_type(),
1101                                               sparse.layout(),
1102                                               new_values.device());
1103 }
1104 
1105 template <typename scalar_t, typename ReductionOp>
1106 Tensor reduce_sparse_csr_dim1_cpu_template(const Tensor& sparse, ReductionOp rop) {
1107   /*
1108     Consider the following sparse tensor:
1109 
1110     1 * * * *
1111     * * * 2 *
1112     * * 3 * *
1113     * * * * *
1114     4 * 5 * *
1115 
1116     that has CSR representation
1117 
1118       crow_indices = [0, 1, 2, 3, 3, 5]
1119       col_indices = [0, 3, 2, 0, 2]
1120       values = [1, 2, 3, 4, 5]
1121 
1122     Reduction with dim=1 results:
1123 
1124     1
1125     2
1126     3
1127     *
1128     rop(4, 5)
1129 
1130     that has CSR representation
1131 
1132       new_crow_indices = [0, 1, 2, 3, 3, 4]
1133       new_col_indices = [0, 0, 0, 0]
1134       new_values = [1, 2, 3, rop(4, 5)]
1135 
1136     In general, the result CSR data can be computed as follows:
1137 
1138       new_crow_indices = [0]
1139       for i in range(1, nrows+1):
1140           new_crow_indices[i] = new_crow_indices[i-1] + (crow_indices[i] != crow_indices[i-1])
1141       nnz = new_crow_indices[-1]
1142       new_col_indices = zeros(nnz)
1143       new_values.resize(nnz)
1144       j = -1
1145       for i in range(1, nrows+1):
1146           if crow_indices[i] == crow_indices[i-1]:
1147               continue
1148           j += 1
1149           new_values[j] = rop(values[crow_indices[i-1] : crow_indices[i]])
1150   */
1151 
1152   Tensor crow_indices = sparse.crow_indices();
1153   auto ioptions = crow_indices.options();
1154   Tensor values = sparse.values();
1155   auto nrows = sparse.size(0);
1156 
1157   Tensor new_crow_indices = at::empty({crow_indices.numel()}, ioptions);
1158   Tensor new_col_indices = at::empty({}, ioptions);
1159   Tensor row_map = at::empty({nrows}, ioptions);
1160 
1161   // Pass `is_cuda = true` to acc_type even on the CPU backend because the
1162   // accumulate type of float should be float here: acc_type<float, true> is
1163   // float, whereas acc_type<float, false> (the CPU default) is double.
1164   using acc_t = at::acc_type<scalar_t, true>;
1165   auto acc_buffer = at::sparse_csr::create_acc_buffer<acc_t, scalar_t>(
1166       values.options(), values.scalar_type());
1167   Tensor new_values = std::get<0>(acc_buffer);
1168   Tensor new_values_acc = std::get<1>(acc_buffer);
1169 
1170   AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "reduce_sparse_csr_dim1_cpu_indices",
1171                           [&]() {
1172     index_t* crow_indices_ptr = crow_indices.data_ptr<index_t>();
1173     index_t* new_crow_indices_ptr = new_crow_indices.data_ptr<index_t>();
1174     index_t* row_map_ptr = row_map.data_ptr<index_t>();
1175     int64_t nnz = 0;
1176     new_crow_indices_ptr[0] = 0;
1177     for(int64_t i=0; i<nrows; i++) {
1178       if (crow_indices_ptr[i] != crow_indices_ptr[i + 1]) {
1179         row_map_ptr[i] = nnz;
1180         nnz++;
1181       }
1182       new_crow_indices_ptr[i + 1] = nnz;
1183     }
1184     new_col_indices.resize_(nnz);
1185     new_col_indices.fill_(index_t(0));
1186     new_values.resize_(nnz);
1187     new_values_acc.resize_(nnz);
1188 
1189     scalar_t* values_ptr = values.data_ptr<scalar_t>();
1190     acc_t* new_values_acc_ptr = new_values_acc.data_ptr<acc_t>();
1191 
1192     at::parallel_for(
1193         0,
1194         nrows,
1195         internal::GRAIN_SIZE,
1196         [&](int64_t irow_start, int64_t irow_end) {
1197             index_t i_end = crow_indices_ptr[irow_start];
1198             for (index_t h = irow_start; h < irow_end; ++h) {
1199               index_t i_start = i_end;
1200               i_end = crow_indices_ptr[h+1];
1201               if (i_start != i_end) {
1202                 acc_t res = static_cast<acc_t>(values_ptr[i_start]);
1203                 for (index_t i = i_start + 1; i < i_end; i++) {
1204                   res = rop(res, static_cast<acc_t>(values_ptr[i]));
1205                 }
1206                 new_values_acc_ptr[row_map_ptr[h]] = res;
1207               }
1208             }
1209         });
1210                           });
1211 
1212   copy_from_acc_buffer(new_values, new_values_acc);
1213 
1214   return at::native::_sparse_csr_tensor_unsafe(new_crow_indices, new_col_indices, new_values,
1215                                                 {sparse.size(0), 1},
1216                                                 new_values.scalar_type(),
1217                                                 sparse.layout(),
1218                                                 new_values.device());
1219 }
1220 
1221 template <typename scalar_t, typename ReductionOp>
1222 Tensor reduce_sparse_csr_dim01_cpu_template(const Tensor& sparse, ReductionOp rop) {
1223 
1224   auto ioptions = sparse.col_indices().options();
1225   Tensor values = sparse.values();
1226   auto numel = values.numel();
1227   auto nnz = std::min<int64_t>(1, numel);
1228 
1229   /* TODO: we can likely do about 3x better than parallel_reduce:
1230 
1231 In [2]: t=torch.randn(5000, 5000).to_sparse_csr()
1232 
1233 In [3]: %timeit torch._sparse_csr_sum(t, dim=(0, 1), keepdim=True)
1234 3.39 ms ± 898 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
1235 
1236 In [4]: %timeit torch.sum(t.values())
1237 1.07 ms ± 291 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1238   */
1239 
1240   // Pass `is_cuda = true` to acc_type even on the CPU backend because the
1241   // accumulate type of float should be float here: acc_type<float, true> is
1242   // float, whereas acc_type<float, false> (the CPU default) is double.
1243   using acc_t = at::acc_type<scalar_t, true>;
1244   scalar_t* values_ptr = values.data_ptr<scalar_t>();
1245   acc_t value = at::parallel_reduce(
1246                                        0,
1247                                        numel,
1248                                        internal::GRAIN_SIZE,
1249                                        rop.identity(),
1250                                        [&](int64_t i_start, int64_t i_end, scalar_t identity) {
1251                                          acc_t res = acc_t(identity);
1252                                          for (int64_t i=i_start; i<i_end; i++) {
1253                                            acc_t val = acc_t(values_ptr[i]);
1254                                            res = rop(res, val);
1255                                          }
1256                                          return res;
1257                                        }, rop
1258                                        );
1259 
1260   Tensor new_col_indices = at::zeros({nnz}, ioptions);
1261   Tensor new_crow_indices = at::tensor(ArrayRef<int64_t>{0, nnz}, ioptions);
1262   Tensor new_values;
1263   auto result_dtype = at::isIntegralType(values.scalar_type(), /*includeBool=*/true) ? ScalarType::Long : values.scalar_type();
1264   if (numel > 0) {
1265     new_values = at::empty({1}, values.options().dtype(result_dtype));
1266     new_values.fill_(value);
1267   } else {
1268     new_values = at::empty({}, values.options().dtype(result_dtype));
1269   }
1270   return at::native::_sparse_csr_tensor_unsafe(new_crow_indices, new_col_indices, new_values,
1271                                                {1, std::min<int64_t>(1, sparse.size(1))},
1272                                                new_values.scalar_type(),
1273                                                sparse.layout(),
1274                                                new_values.device());
1275 }
1276 
1277 template <typename scalar_t, typename ReductionOp>
1278 Tensor reduce_sparse_csr_cpu_template(const Tensor& sparse, std::vector<int64_t> dims, ReductionOp rop) {
1279   if (dims.size() == 1) {
1280     if (dims[0] == 0) {
1281       return reduce_sparse_csr_dim0_cpu_template<scalar_t>(sparse, rop);
1282     } else {
1283       TORCH_INTERNAL_ASSERT(dims[0] == 1);
1284       return reduce_sparse_csr_dim1_cpu_template<scalar_t>(sparse, rop);
1285     }
1286   } else if (dims.size() == 2) {
1287     TORCH_INTERNAL_ASSERT(((dims[0] == 0 && dims[1] == 1) || (dims[0] == 1 && dims[1] == 0)));
1288     return reduce_sparse_csr_dim01_cpu_template<scalar_t>(sparse, rop);
1289   }
1290   TORCH_INTERNAL_ASSERT(dims.empty());
1291   // effective after gh-29137 has been resolved
1292   return sparse.clone();
1293 }
1294 
1295 template <typename scalar_t, typename ReductionOp>
1296 Tensor reduce_sparse_csr_cpu_template(const Tensor& sparse, IntArrayRef dims_to_sum, bool keepdim, ReductionOp rop) {
1297   TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr());
1298   TORCH_CHECK(keepdim, "reduction operations on CSR tensors with keepdim=False are not supported");
1299   TORCH_INTERNAL_ASSERT(sparse.device() == kCPU);
1300 
1301   const int64_t input_dim = sparse.dim();
1302   TORCH_INTERNAL_ASSERT(input_dim == 2);
1303   auto dims = dims_to_sum.vec();
1304   maybe_wrap_dims(dims, input_dim);
1305   if (dims.empty()) {
1306     // after gh-29137 is resolved, delete this if-block
1307     dims.emplace_back(0);
1308     dims.emplace_back(1);
1309   }
1310   return reduce_sparse_csr_cpu_template<scalar_t>(sparse, dims, rop);
1311 }
1312 
1313 template <typename scalar_t>
1314 struct ReductionAddOp {
1315   inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
1316     return a + b;
1317   }
1318   inline scalar_t identity() const { return 0; }
1319 };
1320 
1321 template <typename scalar_t>
1322 struct ReductionMulOp {
1323   inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
1324     return a * b;
1325   }
1326   inline scalar_t identity() const { return 1; }
1327 };
1328 
1329 }  // namespace
1330 
1331 Tensor _sparse_csr_sum_cpu(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
1332   ScalarType dtype_ = dtype.value_or(input.scalar_type());
1333   Tensor input_ = at::sparse_csr::to_type(input, dtype_);
1334   Tensor result;
1335   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
1336       kHalf, kBFloat16, input_.scalar_type(), "_sparse_csr_sum_cpu", [&] {
1337         // Pass `is_cuda = true` to acc_type even on the CPU backend because the
1338         // accumulate type of float should be float here: acc_type<float, true> is
1339         // float, whereas acc_type<float, false> (the CPU default) is double.
1340         using acc_t = at::acc_type<scalar_t, true>;
1341         result = reduce_sparse_csr_cpu_template<scalar_t>(
1342             input_, dims_to_sum, keepdim, ReductionAddOp<acc_t>());
1343       });
1344   return result;
1345 }
1346 
1347 Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, bool keepdim, std::optional<ScalarType> dtype) {
1348   ScalarType dtype_ = dtype.value_or(input.scalar_type());
1349   Tensor input_ = input.to(dtype_);
1350   Tensor result;
1351   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
1352     kHalf, kBFloat16, input_.scalar_type(), "_sparse_csr_prod_cpu",
1353     [&] {
1354       result = reduce_sparse_csr_cpu_template<scalar_t>(input_, dims_to_reduce, keepdim, ReductionMulOp<scalar_t>());
1355     });
1356   return result;
1357 }
1358 
1359 std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_sparse_csr_cpu(
1360     const Tensor& self,
1361     const Tensor& other,
1362     const c10::string_view reduce) {
1363 
1364   auto layout = self.layout();
1365   TORCH_CHECK(layout == kSparseCsr,
1366       "sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
1367   TORCH_CHECK(self.dense_dim() == 0,
1368       "sparse_mm_reduce: expected non-hybrid self tensor.");
1369   TORCH_CHECK(self.dim() == 2,
1370       "sparse_mm_reduce: expected self to be a 2-D tensor, got ", self.dim(), "-D tensor.");
1371 
1372   sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/false>(
1373       self, Tensor(), other);
1374 
1375   auto op = get_reduction_enum(reduce);
1376   TORCH_CHECK(op != ReductionType::PROD, "sparse_mm_reduce: reduce type of prod has not been enabled.");
1377 
1378   auto crow = self.crow_indices();
1379   auto col = self.col_indices();
1380   auto val = self.values();
1381 
1382   // Init the output to all zeros; for rows that have no nonzero elements,
1383   // the corresponding rows in the output stay zero.
1384   auto out = at::zeros({self.size(0), other.size(1)}, other.options());
1385   auto arg_out = at::empty({0}, col.options());
1386 
1387   int64_t nnz = self._nnz();
1388   if (nnz == 0) {
1389     return std::make_tuple(out, arg_out);
1390   }
1391 
1392   // arg_out only needs to be calculated for the "amax" and "amin" reduce
1393   // types, and only during training (i.e. when gradients are required)
1394   bool need_arg_out = at::GradMode::is_enabled()
1395       && (self.requires_grad() || other.requires_grad())
1396       && (op == ReductionType::MAX || op == ReductionType::MIN);
1397 
1398   if (!need_arg_out) {
1399     spmm_reduce_stub(kCPU, out, crow, col, val, other, op);
1400   } else {
1401     // allocate memory and init with invalid index
1402     arg_out.resize_(out.sizes());
1403     arg_out.fill_(nnz);
1404     spmm_reduce_arg_stub(kCPU, out, arg_out, crow, col, val, other, op);
1405   }
1406 
1407   return std::make_tuple(std::move(out), std::move(arg_out));
1408 }
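// Illustrative sketch: this kernel backs the reduce variants of sparse-dense
// matmul on CPU (as far as the editor understands, reachable in Python through
// torch.sparse.mm's `reduce` argument). The second output carries argmax/argmin
// positions only when they are needed for autograd:
//
//   at::Tensor csr = at::eye(3).to_sparse_csr();
//   at::Tensor dense = at::rand({3, 2});
//
//   auto [out, arg_out] = at::_sparse_mm_reduce_impl(csr, dense, "sum");
//   // out: strided 3 x 2 result; arg_out stays empty here (no autograd,
//   // reduce is not "amax"/"amin").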
1409 
1410 std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_backward_sparse_csr_cpu(
1411     const Tensor& self,
1412     const Tensor& grad_out,
1413     const Tensor& other,
1414     const c10::string_view reduce,
1415     const Tensor& arg_out,
1416     std::array<bool, 2> output_mask) {
1417 
1418   auto layout = self.layout();
1419   TORCH_CHECK(layout == kSparseCsr,
1420       "sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
1421 
1422   sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/true>(
1423       self, grad_out, other);
1424 
1425   auto op = get_reduction_enum(reduce);
1426 
1427   auto crow = self.crow_indices();
1428   auto col = self.col_indices();
1429   auto val = self.values();
1430 
1431   // `row`: row indices of COO format
1432   // `ccol`: compressed column (ccol) indices of CSC format (built from the permuted columns)
1433   // `permute`: permute pattern from CSR to CSC
1434   //
1435   // TODO: optimize the following section,
1436   // currently `argsort` is sequential.
1437   Tensor row, ccol, permute;
1438   {
1439     bool out_int32 = crow.scalar_type() == ScalarType::Int;
1440     Tensor coo_indices = at::_convert_indices_from_csr_to_coo(
1441         crow,
1442         col,
1443         out_int32,
1444         /*transpose*/false);
1445     row = coo_indices.select(0, 0);
1446 
1447     // calculate the global index for CSC
1448     // and get the conversion permute pattern
1449     Tensor index = col.mul(self.size(0)).add_(row);
1450     permute = index.argsort();
1451 
1452     ccol = at::_convert_indices_from_coo_to_csr(
1453         /*column indices*/col.index_select(0, permute),
1454         /*column count*/self.size(1),
1455         out_int32);
1456   }
1457 
1458   Tensor grad_self, grad_other;
1459   if (output_mask[0]) {
1460     // grad_input has the same indices and nnz as the input
1461     grad_self = at::empty_like(self);
1462     grad_self.values().zero_();
1463     if (op == ReductionType::MAX || op == ReductionType::MIN) {
1464       spmm_reduce_backward_input_arg_stub(kCPU, grad_self, grad_out, col, other, arg_out, op);
1465     } else {
1466       spmm_reduce_backward_input_stub(kCPU, grad_self, grad_out, crow, col, other, row, op);
1467     }
1468   }
1469   if (output_mask[1]) {
1470     grad_other = at::zeros(other.sizes(), other.options());
1471     if (op == ReductionType::MAX || op == ReductionType::MIN) {
1472       spmm_reduce_backward_other_arg_stub(kCPU, grad_other, grad_out, col, val, arg_out, op);
1473     } else {
1474       spmm_reduce_backward_other_stub(kCPU, grad_other, grad_out, crow, val, row, ccol, permute, op);
1475     }
1476   }
1477 
1478   return std::make_tuple(std::move(grad_self), std::move(grad_other));
1479 }
1480 
1481 DEFINE_DISPATCH(spmm_reduce_stub);
1482 DEFINE_DISPATCH(spmm_reduce_arg_stub);
1483 DEFINE_DISPATCH(spmm_reduce_backward_input_stub);
1484 DEFINE_DISPATCH(spmm_reduce_backward_input_arg_stub);
1485 DEFINE_DISPATCH(spmm_reduce_backward_other_stub);
1486 DEFINE_DISPATCH(spmm_reduce_backward_other_arg_stub);
1487 
1488 } // namespace native
1489 } // namespace at
1490