1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/ExpandUtils.h>
4 #include <ATen/Parallel.h>
5 #include <ATen/SparseCsrTensorUtils.h>
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/grad_mode.h>
8 #include <ATen/mkl/Sparse.h>
9 #include <ATen/native/BinaryOps.h>
10 #include <ATen/native/CPUBlas.h>
11 #include <ATen/native/Resize.h>
12 #include <ATen/native/SparseTensorUtils.h>
13 #include <ATen/native/TensorConversions.h>
14 #include <ATen/native/mkl/SparseBlasImpl.h>
15 #include <ATen/native/sparse/SparseBlasImpl.h>
16 #include <ATen/native/sparse/SparseCsrTensorMath.h>
17 #include <c10/macros/Macros.h>
18 #include <c10/util/irange.h>
19 #include <ATen/AccumulateType.h>
20
21 #ifndef AT_PER_OPERATOR_HEADERS
22 #include <ATen/Functions.h>
23 #include <ATen/NativeFunctions.h>
24 #include <ATen/Operators.h>
25 #else
26 #include <ATen/ops/_conj_physical_native.h>
27 #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
28 #include <ATen/ops/_convert_indices_from_coo_to_csr_native.h>
29 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
30 #include <ATen/ops/_convert_indices_from_csr_to_coo_native.h>
31 #include <ATen/ops/_sparse_bsr_tensor_unsafe_native.h>
32 #include <ATen/ops/_sparse_compressed_tensor_unsafe_native.h>
33 #include <ATen/ops/_sparse_csr_prod_native.h>
34 #include <ATen/ops/_sparse_csr_sum_native.h>
35 #include <ATen/ops/_sparse_csr_tensor_unsafe_native.h>
36 #include <ATen/ops/_sparse_mm_reduce_impl_backward_native.h>
37 #include <ATen/ops/_sparse_mm_reduce_impl_native.h>
38 #include <ATen/ops/_unique.h>
39 #include <ATen/ops/abs.h>
40 #include <ATen/ops/abs_native.h>
41 #include <ATen/ops/add.h>
42 #include <ATen/ops/add_native.h>
43 #include <ATen/ops/addmm.h>
44 #include <ATen/ops/addmm_native.h>
45 #include <ATen/ops/angle.h>
46 #include <ATen/ops/angle_native.h>
47 #include <ATen/ops/asin.h>
48 #include <ATen/ops/asin_native.h>
49 #include <ATen/ops/asinh.h>
50 #include <ATen/ops/asinh_native.h>
51 #include <ATen/ops/atan.h>
52 #include <ATen/ops/atan_native.h>
53 #include <ATen/ops/atanh.h>
54 #include <ATen/ops/atanh_native.h>
55 #include <ATen/ops/ceil.h>
56 #include <ATen/ops/ceil_native.h>
57 #include <ATen/ops/conj_physical.h>
58 #include <ATen/ops/conj_physical_native.h>
59 #include <ATen/ops/copy_native.h>
60 #include <ATen/ops/deg2rad.h>
61 #include <ATen/ops/deg2rad_native.h>
62 #include <ATen/ops/empty.h>
63 #include <ATen/ops/empty_like.h>
64 #include <ATen/ops/erf.h>
65 #include <ATen/ops/erf_native.h>
66 #include <ATen/ops/erfinv.h>
67 #include <ATen/ops/erfinv_native.h>
68 #include <ATen/ops/expm1.h>
69 #include <ATen/ops/expm1_native.h>
70 #include <ATen/ops/fill_native.h>
71 #include <ATen/ops/floor.h>
72 #include <ATen/ops/floor_native.h>
73 #include <ATen/ops/frac.h>
74 #include <ATen/ops/frac_native.h>
75 #include <ATen/ops/isinf.h>
76 #include <ATen/ops/isinf_native.h>
77 #include <ATen/ops/isnan.h>
78 #include <ATen/ops/isnan_native.h>
79 #include <ATen/ops/isneginf.h>
80 #include <ATen/ops/isneginf_native.h>
81 #include <ATen/ops/isposinf.h>
82 #include <ATen/ops/isposinf_native.h>
83 #include <ATen/ops/log1p.h>
84 #include <ATen/ops/log1p_native.h>
85 #include <ATen/ops/mm_native.h>
86 #include <ATen/ops/mul.h>
87 #include <ATen/ops/mul_native.h>
88 #include <ATen/ops/neg.h>
89 #include <ATen/ops/neg_native.h>
90 #include <ATen/ops/normal_native.h>
91 #include <ATen/ops/ones.h>
92 #include <ATen/ops/ones_like.h>
93 #include <ATen/ops/rad2deg.h>
94 #include <ATen/ops/rad2deg_native.h>
95 #include <ATen/ops/relu.h>
96 #include <ATen/ops/relu_native.h>
97 #include <ATen/ops/resize_as_sparse_native.h>
98 #include <ATen/ops/result_type.h>
99 #include <ATen/ops/round.h>
100 #include <ATen/ops/round_native.h>
101 #include <ATen/ops/round_ops.h>
102 #include <ATen/ops/sgn.h>
103 #include <ATen/ops/sgn_native.h>
104 #include <ATen/ops/sign.h>
105 #include <ATen/ops/sign_native.h>
106 #include <ATen/ops/signbit.h>
107 #include <ATen/ops/signbit_native.h>
108 #include <ATen/ops/sin.h>
109 #include <ATen/ops/sin_native.h>
110 #include <ATen/ops/sinh.h>
111 #include <ATen/ops/sinh_native.h>
112 #include <ATen/ops/sparse_mask.h>
113 #include <ATen/ops/sparse_mask_native.h>
114 #include <ATen/ops/sqrt.h>
115 #include <ATen/ops/sqrt_native.h>
116 #include <ATen/ops/tan.h>
117 #include <ATen/ops/tan_native.h>
118 #include <ATen/ops/tanh.h>
119 #include <ATen/ops/tanh_native.h>
120 #include <ATen/ops/tensor.h>
121 #include <ATen/ops/threshold_backward.h>
122 #include <ATen/ops/threshold_backward_native.h>
123 #include <ATen/ops/trunc.h>
124 #include <ATen/ops/trunc_native.h>
125 #include <ATen/ops/zero_native.h>
126 #include <ATen/ops/zeros.h>
127 #include <ATen/ops/zeros_like.h>
128 #endif
129
130 #include <algorithm>
131
132 namespace at {
133 namespace meta {
134
135 TORCH_META_FUNC(_convert_indices_from_coo_to_csr)
136 (const Tensor& self, const int64_t size, const bool out_int32) {
137 TORCH_CHECK(self.dim() <= 1, "Input is supposed to be a vector, but got ",
138 self.dim(), " dimensional tensor.");
139 ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
140 c10::TensorOptions options =
141 TensorOptions().device(self.options().device()).dtype(scalar_type);
142 set_output_raw_strided(0, size + 1, {}, options);
143 }
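// Worked example (comment only): for sorted COO row indices [0, 0, 1, 3] and
// size = 4, the corresponding compressed row indices are [0, 2, 3, 3, 4],
// i.e. a vector of length size + 1 whose consecutive differences give the
// number of stored elements per row.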
144
145 TORCH_META_FUNC(_convert_indices_from_csr_to_coo)
146 (const Tensor& crow_indices,
147 const Tensor& col_indices,
148 const bool out_int32,
149 const bool transpose) {
150 TORCH_CHECK(
151 crow_indices.dim() == col_indices.dim(), "crow_indices and col_indices are supposed to have"
152 " the same dimensionality, but got ", crow_indices.dim(), " and ",
153 col_indices.dim(), " dimensional tensors, respectively.");
154 ScalarType scalar_type = out_int32 ? ScalarType::Int : ScalarType::Long;
155 c10::TensorOptions options = crow_indices.options().dtype(scalar_type);
156 set_output_raw_strided(0, {col_indices.dim() + 1, col_indices.numel()}, {}, options, {});
157 }
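// Worked example (comment only): for a non-batched input with
// crow_indices = [0, 2, 3, 3, 4] and col_indices = [1, 3, 0, 2], the output
// has shape {2, 4} and (with transpose == false) holds the COO indices
// [[0, 0, 1, 3], [1, 3, 0, 2]].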
158
159 } // namespace meta
160
161 namespace {
162
163 template <typename F>
164 Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result) {
165 TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
166 TORCH_INTERNAL_ASSERT(result.is_sparse_csr());
167
168 if (!result.is_same(self)) {
169 // For the case of (0x0) result tensor, manually resize `result` tensor
170 // to the size of `self` tensor
171 if (result.numel() == 0) {
172 at::native::resize_as_sparse_compressed_(result, self);
173 }
174 // copy_sparse_compressed_ internally checks the sizes of result and self tensors
175 // Hence no external size check required
176 at::native::copy_sparse_compressed_(result, self);
177 }
178
179 auto self_values = self.values();
180 auto result_values = result.values();
181
182 op_out(self_values, result_values);
183 return result;
184 }
185
186 template <typename F, typename... Args>
187 Tensor& unary_op_inplace(Tensor& self, const F& op_inplace, Args&&... args) {
188 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "unary_op_inplace", [](){});
189
190 auto self_values = self.values();
191 (self_values.*op_inplace)(std::forward<Args>(args)...);
192 return self;
193 }
194
195 } // end anonymous namespace
196
197 namespace native {
198
199 using namespace at::sparse_csr;
200 // certain utility functions are usable from sparse COO.
201 using namespace at::sparse;
202
203 Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
204 // TODO: Use a specialized CSR kernel for performance if needed
205 if (t_.is_sparse_csr() && src_.layout() == kStrided) {
206 return mul_out_sparse_csr(t_, src_.sparse_mask(t_), r);
207 }
208 if (t_.layout() == kStrided && src_.is_sparse_csr()) {
209 return mul_out_sparse_csr(t_.sparse_mask(src_), src_, r);
210 }
211 TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");
212 Tensor t = t_.to_sparse();
213 Tensor src = src_.to_sparse();
214 Tensor tmp_result = t.mul(src);
215 auto r_sparse_csr = tmp_result.to_sparse_csr();
216 r.resize_as_sparse_(r_sparse_csr);
217 r.copy_(r_sparse_csr);
218 return r;
219 }
220
221 template <typename op_t>
222 Tensor intersection_binary_op_with_wrapped_scalar(const Tensor& sparse, const Tensor& scalar, const op_t& op) {
223 // NOTE: intersection_binary_op_with_wrapped_scalar assumes scalar.numel() == 1.
224 const auto result_values = op(sparse.values(), scalar.squeeze()).to(at::result_type(sparse, scalar));
225 const auto result_sizes = infer_size(sparse.sizes(), scalar.sizes());
226 auto [compressed_indices, plain_indices] = getCompressedPlainIndices(sparse);
227 return at::_sparse_compressed_tensor_unsafe(
228 compressed_indices.clone(),
229 plain_indices.clone(),
230 result_values,
231 result_sizes,
232 sparse.options().dtype(result_values.scalar_type()));
233 }
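// Illustrative sketch (not part of the dispatch logic): for a CSR tensor with
// values [1., 2., 3.] and a wrapped 0-dim scalar tensor holding 4., calling
// this helper with op = mul yields a tensor with the same compressed/plain
// indices and values [4., 8., 12.]; only `values` participates in `op`, so
// the sparsity pattern of `sparse` is preserved.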
234
235 template <typename op_t>
236 Tensor& intersection_binary_op_with_wrapped_scalar_(Tensor& sparse, const Tensor& scalar, const string& op_name, const op_t& op) {
237 // NOTE: intersection_binary_op_with_wrapped_scalar_ assumes scalar.numel() == 1.
238 const auto broadcasted_shape = infer_size(sparse.sizes(), scalar.sizes());
239 if (sparse.sizes() != broadcasted_shape) {
240 TORCH_CHECK(false, op_name, "(): output with shape ", sparse.sizes(), " does not match ",
241 "the broadcast shape ", broadcasted_shape);
242 }
243 auto values = sparse.values();
244 // Safe to use squeeze here, we already know that scalar safely broadcasts.
245 op(values, scalar.squeeze());
246 return sparse;
247 }
248
249 Tensor mul_sparse_csr(const Tensor& self, const Tensor& other) {
250 // Check if either of the arguments is a wrapped Scalar
251 if (self.layout() == kStrided && self.dim() == 0) {
252 return intersection_binary_op_with_wrapped_scalar(other, self, [](const Tensor& a, const Tensor& b) -> Tensor {
253 return a.mul(b);
254 });
255 }
256 if (other.layout() == kStrided && other.dim() == 0) {
257 return intersection_binary_op_with_wrapped_scalar(self, other, [](const Tensor& a, const Tensor& b) -> Tensor {
258 return a.mul(b);
259 });
260 }
261
262 if (self.is_sparse_csr() && other.layout() == kStrided) {
263 return mul_sparse_csr(self, other.sparse_mask(self));
264 }
265 if (self.layout() == kStrided && other.is_sparse_csr()) {
266 return mul_sparse_csr(self.sparse_mask(other), other);
267 }
268
269 auto commonDtype = at::result_type(self, other);
270 auto result_options = self.options().dtype(commonDtype);
271 // CSR is 2d!
272 Tensor result = at::empty({0, 0}, result_options);
273 return at::mul_out(result, self, other); // redispatch!
274 }
275
276 Tensor& mul_sparse_csr_(Tensor& self, const Tensor& other) {
277 if (other.layout() == kStrided && other.dim() == 0) {
278 return intersection_binary_op_with_wrapped_scalar_(self, other, "mul_", [](Tensor& a, const Tensor& b) -> Tensor& {
279 return a.mul_(b);
280 });
281 }
282 return at::mul_out(self, self, other); // redispatch!
283 }
284
285
286 namespace {
287
288 template <typename F>
289 inline Tensor get_result_tensor_for_unary_op(F op, const Tensor& input) {
290 auto values = input.values();
291
292 // To handle type promotion for inputs to unary ops,
293 // we first get the result from the underlying op, and use the result
294 // to create a sparse compressed tensor, which is used as the input to the out=
295 // variant
296 auto result_values = op(values);
297
298 auto compressed_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(),
299 "get_result_tensor_for_unary_op",
300 [&]{ return input.crow_indices(); },
301 [&]{ return input.ccol_indices(); });
302 auto plain_indices = AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(input.layout(),
303 "get_result_tensor_for_unary_op",
304 [&]{ return input.col_indices(); },
305 [&]{ return input.row_indices(); });
306
307 auto result = at::_sparse_compressed_tensor_unsafe(
308 compressed_indices.clone(),
309 plain_indices.clone(),
310 result_values,
311 input.sizes(),
312 input.options().dtype(result_values.scalar_type()));
313
314 return result;
315 }
316 } // namespace
317
318 Tensor& normal_sparse_csr_(
319 Tensor& self,
320 double mean,
321 double std,
322 std::optional<Generator> gen) {
323 return unary_op_inplace(self, &Tensor::normal_, mean, std, gen);
324 }
325
326 Tensor& fill_sparse_csr_(Tensor& self, const Scalar& value) {
327 return unary_op_inplace(self, &TensorBase::fill_, value);
328 }
329
330 Tensor sparse_mask_sparse_compressed(
331 const Tensor& self,
332 const Tensor& mask) {
333 TORCH_CHECK(at::sparse_csr::is_sparse_compressed(mask),
334 "sparse_mask_sparse_compressed expects mask to have sparse compressed layout, got ", mask.layout());
335 TORCH_CHECK(
336 mask.sizes().equals(self.sizes()),
337 "sparse_mask(): operands have incompatible sizes; self has size ",
338 self.sizes(),
339 " but mask has size ",
340 mask.sizes());
341
342 if (self.is_same(mask)) {
343 return self;
344 }
345
346 if (!mask.numel() || !mask._nnz()) {
347 return mask.clone().to(self.device(), self.scalar_type());
348 }
349
350 if (self.layout() == kStrided) {
351 auto [compressed_indices, plain_indices] = at::sparse_csr::getCompressedPlainIndices(mask);
352 auto mask_values = mask.values();
353 auto dense_mask = at::_sparse_compressed_tensor_unsafe(
354 compressed_indices,
355 plain_indices,
356 at::ones({1}, self.options().dtype(kBool)).expand_as(mask_values),
357 self.sizes(),
358 self.options().dtype(kBool).layout(mask.layout())).to_dense();
359 return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
360 mask.layout(), "sparse_mask_sparse_compressed",
361 [&] {
362 return at::native::dense_to_sparse_with_mask(self, dense_mask, mask.layout(), {}, mask.dense_dim());
363 },
364 [&] {
365 auto blocksize = at::sparse_csr::getBlockSize(mask);
366 return at::native::dense_to_sparse_with_mask(self, dense_mask, mask.layout(), blocksize, mask.dense_dim());
367 });
368 } else if (self.layout() == mask.layout()) {
369 // TODO: keeping this for BC but the method used here may lead to
370 // incorrect indices.
371 return self.mul(at::ones_like(mask)).to(self.scalar_type());
372 } else {
373 // TODO: keeping this for BC but the method used here cannot
374 // support batch dimensions because sparse COO tensors are batch
375 // dimension ignorant.
376 return AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS(
377 mask.layout(), "sparse_mask_sparse_compressed",
378 [&] {
379 return self.sparse_mask(mask.to_sparse()).to_sparse(mask.layout());
380 },
381 [&] {
382 auto blocksize = at::sparse_csr::getBlockSize(mask);
383 return self.sparse_mask(mask.to_sparse()).to_sparse(mask.layout(), blocksize);
384 });
385 }
386 }
387
388 Tensor mul_scalar_sparse_csr(const Tensor& self, const Scalar& other) {
389 auto result_values = self.values().mul(other);
390 return at::native::_sparse_csr_tensor_unsafe(
391 self.crow_indices().clone(),
392 self.col_indices().clone(),
393 result_values,
394 self.sizes(),
395 result_values.scalar_type(),
396 self.layout(),
397 result_values.device());
398 }
399
400 Tensor& zero_sparse_csr_(Tensor& self) {
401 /*
402 csr.zero_() resets nnz to 0.
403
404 If the original sparsity pattern needs to be preserved, use
405 `csr.values().zero_()` instead.
406
407 The above behavior also implies that torch.zeros_like(csr) returns
408 a new tensor with nnz == 0. If one needs a zeros_like semantics
409 where the result has the same sparsity pattern as input, then use
410 `result = csr.clone(); result.values.zero_();`
411 */
412 AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(self.layout(), "zero_sparse_csr_", [](){});
413 get_sparse_csr_impl(self)->resize_and_clear_(self.sparse_dim(), self.dense_dim(), self.sizes());
414 return self;
415 }
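// Minimal usage sketch mirroring the note above (assumes `x` is sparse CSR):
//   auto a = x.clone(); a.zero_();          // nnz becomes 0
//   auto b = x.clone(); b.values().zero_(); // sparsity pattern is preserved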
416
417 /* Implementation of the unary ufuncs supported for the sparse CSR layout.
418 * Only simple funcs with a 0->0 correspondence are currently supported. */
419
420 #define CREATE_UNARY_UFUNC_OUT(op_name) \
421 Tensor& op_name##_sparse_csr_out(const Tensor& self, Tensor& result) { \
422 return unary_op_out(&at::op_name##_outf, self, result); \
423 }
424
425 #define CREATE_UNARY_UFUNC_FUNCTIONAL(op_name) \
426 Tensor op_name##_sparse_csr(const Tensor& self) { \
427 return get_result_tensor_for_unary_op(&at::op_name, self); \
428 }
429
430 #define CREATE_UNARY_UFUNC_INPLACE(op_name) \
431 Tensor& op_name##_sparse_csr_(Tensor& self) { \
432 return unary_op_inplace(self, &Tensor::op_name##_); \
433 }
434
435 #define CREATE_UNARY_UFUNC(op_name) \
436 CREATE_UNARY_UFUNC_OUT(op_name); \
437 CREATE_UNARY_UFUNC_FUNCTIONAL(op_name); \
438 CREATE_UNARY_UFUNC_INPLACE(op_name);
439
440 #define CREATE_UNARY_UFUNC_NO_INPLACE(op_name) \
441 CREATE_UNARY_UFUNC_OUT(op_name); \
442 CREATE_UNARY_UFUNC_FUNCTIONAL(op_name);
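
// For reference, CREATE_UNARY_UFUNC(sqrt) expands (modulo the trailing
// semicolons) to:
//
//   Tensor& sqrt_sparse_csr_out(const Tensor& self, Tensor& result) {
//     return unary_op_out(&at::sqrt_outf, self, result);
//   }
//   Tensor sqrt_sparse_csr(const Tensor& self) {
//     return get_result_tensor_for_unary_op(&at::sqrt, self);
//   }
//   Tensor& sqrt_sparse_csr_(Tensor& self) {
//     return unary_op_inplace(self, &Tensor::sqrt_);
//   }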
443
444 // Exhaustive list of the unary ufuncs supported by sparse compressed tensors
445 CREATE_UNARY_UFUNC(abs);
446 CREATE_UNARY_UFUNC(asin);
447 CREATE_UNARY_UFUNC(asinh);
448 CREATE_UNARY_UFUNC(atan);
449 CREATE_UNARY_UFUNC(atanh);
450 CREATE_UNARY_UFUNC(ceil);
451 CREATE_UNARY_UFUNC(deg2rad);
452 CREATE_UNARY_UFUNC(erf);
453 CREATE_UNARY_UFUNC(erfinv);
454 CREATE_UNARY_UFUNC(expm1);
455 CREATE_UNARY_UFUNC(floor);
456 CREATE_UNARY_UFUNC(frac);
457 CREATE_UNARY_UFUNC(log1p);
458 CREATE_UNARY_UFUNC(neg);
459 CREATE_UNARY_UFUNC(rad2deg);
460 CREATE_UNARY_UFUNC(sign);
461 CREATE_UNARY_UFUNC(sin);
462 CREATE_UNARY_UFUNC(sinh);
463 CREATE_UNARY_UFUNC(sgn);
464 CREATE_UNARY_UFUNC(sqrt);
465 CREATE_UNARY_UFUNC(tan);
466 CREATE_UNARY_UFUNC(tanh);
467 CREATE_UNARY_UFUNC(trunc);
468 CREATE_UNARY_UFUNC(conj_physical);
469
470 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")
471 static CREATE_UNARY_UFUNC(relu);
472 C10_DIAGNOSTIC_POP()
473
474 // With addition of `round.decimals` overload, using CREATE_UNARY_UFUNC leads
475 // to unresolved overload.
476 Tensor& round_sparse_csr_out(const Tensor& self, Tensor& result) {
477 return unary_op_out(&at::_ops::round_out::call, self, result);
478 }
479
480 Tensor round_sparse_csr(const Tensor& self) {
481 return get_result_tensor_for_unary_op(&at::_ops::round::call, self);
482 }
483
484 Tensor& round_sparse_csr_(Tensor& self) {
485 TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
486 self.values().round_();
487 return self;
488 }
489
490 Tensor threshold_backward_sparse_compressed(
491 const Tensor& grad_output,
492 const Tensor& self,
493 const Scalar& threshold) {
494 return get_result_tensor_for_unary_op(
495 [&](const Tensor& t) {
496 return at::threshold_backward(t, self.values(), threshold);
497 },
498 grad_output);
499 }
500
501 Tensor& threshold_backward_sparse_compressed_out(
502 const Tensor& grad_output,
503 const Tensor& self,
504 const Scalar& threshold,
505 Tensor& grad_input) {
506 return unary_op_out(
507 [&](const Tensor& t, Tensor& out) {
508 return at::threshold_backward_outf(t, self.values(), threshold, out);
509 },
510 grad_output,
511 grad_input);
512 }
513
514 // angle, isneginf, isposinf and signbit currently don't have an inplace variant
515 CREATE_UNARY_UFUNC_NO_INPLACE(angle);
516 CREATE_UNARY_UFUNC_NO_INPLACE(isneginf);
517 CREATE_UNARY_UFUNC_NO_INPLACE(isposinf);
518 CREATE_UNARY_UFUNC_NO_INPLACE(signbit);
519
520 // isnan and isinf don't have an out variant
521 CREATE_UNARY_UFUNC_FUNCTIONAL(isnan);
522 CREATE_UNARY_UFUNC_FUNCTIONAL(isinf);
523
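// Reference CPU kernel: computes r = beta * r + alpha * (sparse @ dense) for a
// CSR `sparse` and strided `dense`/`r`, issuing one axpy per stored element;
// row h of `r` accumulates alpha * values[i] * dense[col_indices[i], :] for
// every i in [crow_indices[h], crow_indices[h + 1]).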
524 template <typename scalar_t>
525 void addmm_out_sparse_csr_native_cpu(
526 const Tensor& sparse,
527 const Tensor& dense,
528 const Tensor& r,
529 Scalar alpha,
530 Scalar beta) {
531 auto dim_i = sparse.size(0);
532 auto dim_k = dense.size(1);
533
534 auto csr = sparse.crow_indices();
535 auto col_indices = sparse.col_indices();
536 auto values = sparse.values();
537
538 scalar_t cast_alpha = alpha.to<scalar_t>();
539 r.mul_(beta);
540 AT_DISPATCH_INDEX_TYPES(
541 col_indices.scalar_type(), "csr_mm_crow_indices", [&]() {
542 auto csr_accessor = csr.accessor<index_t, 1>();
543 auto col_indices_accessor = col_indices.accessor<index_t, 1>();
544
545 auto values_accessor = values.accessor<scalar_t, 1>();
546 scalar_t* dense_ptr = dense.data_ptr<scalar_t>();
547 scalar_t* r_ptr = r.data_ptr<scalar_t>();
548
549 int64_t dense_stride0 = dense.stride(0);
550 int64_t dense_stride1 = dense.stride(1);
551 int64_t r_stride0 = r.stride(0);
552 int64_t r_stride1 = r.stride(1);
553
554 at::parallel_for(
555 0,
556 dim_i,
557 internal::GRAIN_SIZE,
558 [&](int64_t irow_start, int64_t irow_end) {
559 for (index_t h = irow_start; h < irow_end; ++h) {
560 index_t i_start = csr_accessor[h];
561 index_t i_end = csr_accessor[h + 1];
562 for (index_t i = i_start; i < i_end; i++) {
563 scalar_t val = values_accessor[i];
564 index_t col = col_indices_accessor[i];
565 at::native::cpublas::axpy<scalar_t>(
566 dim_k,
567 cast_alpha * val,
568 dense_ptr + col * dense_stride0,
569 dense_stride1,
570 r_ptr + h * r_stride0,
571 r_stride1);
572 }
573 }
574 });
575 });
576 }
577
578 // Functions for matrix multiplication.
579 // result = beta * self + alpha (mat1 @ mat2)
580 Tensor& addmm_out_sparse_compressed_cpu(
581 const Tensor& self,
582 const Tensor& mat1,
583 const Tensor& mat2,
584 const Scalar& beta,
585 const Scalar& alpha,
586 Tensor& result) {
587 // All the checks are from addmm_out_cuda_impl (ATen/native/cuda/Blas.cpp) and
588 // TORCH_META_FUNC(addmm) (ATen/native/LinearAlgebra.cpp)
589 // TODO: remove code duplication and unify code
590 sparse::impl::_check_dim(mat1, 2, "mat1");
591 sparse::impl::_check_dim(mat2, 2, "mat2");
592
593 TORCH_CHECK(
594 mat1.size(1) == mat2.size(0), "mat1 and mat2 shapes cannot be multiplied (",
595 mat1.size(0), "x", mat1.size(1), " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");
596
597 c10::MaybeOwned<at::Tensor> self_;
598 // Don't expand self if this is an in-place operation
599 if (&result == &self) {
600 self_ = c10::MaybeOwned<Tensor>::borrowed(self);
601 } else {
602 self_ = expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");
603 }
604
605
606 TORCH_CHECK(((self_->dim() == 2) &&
607 (self_->size(0) == mat1.size(0)) &&
608 (self_->size(1) == mat2.size(1))),
609 "The input tensor must be a matrix with size ",
610 mat1.size(0),
611 "x",
612 mat2.size(1),
613 ", but got a ",
614 self_->dim(),
615 "-D tensor with size ",
616 self_->size(0),
617 "x",
618 self_->size(1));
619
620 if (&result != &self) {
621 if (result.layout() == kStrided) {
622 at::native::resize_output(result, self_->sizes());
623 } else {
624 result.resize_as_sparse_(*self_);
625 }
626 result.copy_(*self_);
627 }
628
629 if (result.numel() == 0) {
630 // If result gets resized and is sparse compressed,
631 // its compressed_indices tensor will contain junk values
632 // so the whole tensor is not a valid compressed tensor.
633 // To combat that, result needs to get zeroed out.
634 if (at::sparse_csr::is_sparse_compressed(result)) {
635 result.zero_();
636 }
637 return result;
638 }
639
640 if (sparse::impl::_is_sparse_and_zero(mat1) || sparse::impl::_is_sparse_and_zero(mat2)) {
641 // According to the docs, when beta == 0 the values in self should be
642 // ignored; NaNs and Infs must not propagate.
643 if (beta.toComplexDouble() == 0.) {
644 result.zero_();
645 } else {
646 result.mul_(beta);
647 }
648 return result;
649 }
650
651 #if !AT_USE_MKL_SPARSE()
652 // The custom impl addmm_out_sparse_csr_native_cpu only supports CSR @
653 // strided -> strided
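// Without MKL, the layout combinations handled below (all with a strided
// result) are:
//   CSR     @ strided   (handled directly)
//   CSC     @ strided   (mat1 converted to CSR first)
//   strided @ CSR       (computed on the transposed operands and result)
//   strided @ CSC       (transposing the CSC operand already yields CSR)
// Everything else falls through to the TORCH_CHECK error at the end.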
654 if (mat1.layout() == kStrided) {
655 if (mat2.layout() == kSparseCsr) {
656 if (result.layout() == kStrided) {
657 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
658 result.scalar_type(), "addmm_sparse_dense", [&] {
659 addmm_out_sparse_csr_native_cpu<scalar_t>(
660 mat2.transpose(-2, -1).to_sparse_csr(),
661 mat1.transpose(-2, -1),
662 result.transpose(-2, -1),
663 alpha,
664 beta);
665 });
666 return result;
667 }
668 }
669 if (mat2.layout() == kSparseCsc) {
670 if (result.layout() == kStrided) {
671 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
672 result.scalar_type(), "addmm_sparse_dense", [&] {
673 addmm_out_sparse_csr_native_cpu<scalar_t>(
674 mat2.transpose(-2, -1),
675 mat1.transpose(-2, -1),
676 result.transpose(-2, -1),
677 alpha,
678 beta);
679 });
680 return result;
681 }
682 }
683 } else if (mat1.layout() == kSparseCsr) {
684 if (mat2.layout() == kStrided) {
685 if (result.layout() == kStrided) {
686 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
687 result.scalar_type(), "addmm_sparse_dense", [&] {
688 addmm_out_sparse_csr_native_cpu<scalar_t>(
689 mat1, mat2, result, alpha, beta);
690 });
691 return result;
692 }
693 }
694 } else if (mat1.layout() == kSparseCsc) {
695 if (mat2.layout() == kStrided) {
696 if (result.layout() == kStrided) {
697 AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
698 result.scalar_type(), "addmm_sparse_dense", [&] {
699 addmm_out_sparse_csr_native_cpu<scalar_t>(
700 mat1.to_sparse_csr(), mat2, result, alpha, beta);
701 });
702 return result;
703 }
704 }
705 }
706 TORCH_CHECK(
707 false,
708 "addmm: computation on CPU is not implemented for ",
709 result.layout(),
710 " + ",
711 mat1.layout(),
712 " @ ",
713 mat2.layout(),
714 " without MKL. PyTorch built with MKL has better support for addmm with sparse CPU tensors.");
715 #else
716 sparse::impl::mkl::addmm_out_sparse_csr(mat1, mat2, beta, alpha, result);
717 #endif
718 return result;
719 }
720
721 Tensor addmm_sparse_compressed_dense(
722 const Tensor& self,
723 const SparseCsrTensor& sparse,
724 const Tensor& dense,
725 const Scalar& beta,
726 const Scalar& alpha) {
727 Tensor r = at::empty({0, 0}, self.options());
728 at::addmm_out(r, self, sparse, dense, beta, alpha);
729 return r;
730 }
731
732 Tensor& _sparse_csr_mm_out(
733 const Tensor& mat1,
734 const Tensor& mat2,
735 Tensor& result) {
736 auto zero = at::zeros_like(result);
737 return at::addmm_out(result, zero, mat1, mat2, 0.0, 1.0);
738 }
739
740 Tensor _sparse_csr_mm(const Tensor& mat1, const Tensor& mat2) {
741 if (mat1.is_sparse_csr() && mat2.is_sparse_csr()) {
742 // Return sparse
743 return at::addmm(
744 at::zeros({mat1.size(0), mat2.size(1)}, mat2.options()),
745 mat1,
746 mat2,
747 0.0,
748 1.0);
749 }
750 if ((mat1.layout() == kSparseCsc || mat1.layout() == kSparseCsr) &&
751 (mat2.layout() == kSparseCsc || mat2.layout() == kSparseCsr)) {
752 // TODO: Expensive conversion to CSR. Should add native support for CSC.
753 // Covers CSC @ CSR
754 // Covers CSR @ CSC
755 // Covers CSC @ CSC
756 return _sparse_csr_mm(mat1.to_sparse_csr(), mat2.to_sparse_csr());
757 }
758 if (mat1.layout() == kSparseCsc && mat2.layout() == c10::kStrided) {
759 // TODO: This is a costly conversion. We should have
760 // native support for CSC.
761 return _sparse_csr_mm(mat1.to_sparse_csr(), mat2);
762 }
763 // Default to taking options from mat1
764 auto result_options = mat1.options();
765 if (mat2.layout() == kStrided) {
766 // if either arg is strided we return strided, so update the options if
767 // mat2 is strided.
768 result_options = result_options.layout(kStrided);
769 }
770 return at::addmm(
771 at::zeros({mat1.size(0), mat2.size(1)}, result_options),
772 mat1,
773 mat2,
774 0.0,
775 1.0);
776 }
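// Usage note: an at::mm call with a sparse compressed operand is expected to
// land here and is evaluated as addmm(zeros(...), mat1, mat2, /*beta=*/0,
// /*alpha=*/1), so the result layout follows the rules above: strided if
// either operand is strided, sparse otherwise.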
777
778 // Functions for element-wise addition.
779 Tensor add_sparse_csr(
780 const Tensor& self,
781 const Tensor& other,
782 const Scalar& alpha) {
783 auto commonDtype = at::result_type(self, other);
784 alpha_check(commonDtype, alpha);
785 Tensor result;
786 if (self.layout() != kStrided && other.layout() == kStrided) {
787 // add(sparse, dense) -> dense
788 result = at::empty_like(
789 other,
790 other.options()
791 .dtype(commonDtype)
792 .memory_format(at::MemoryFormat::Contiguous));
793 } else {
794 // add(dense, sparse) -> dense AND add(sparse, sparse) -> sparse
795 result = at::empty_like(
796 self,
797 self.options()
798 .dtype(commonDtype)
799 .memory_format(at::MemoryFormat::Contiguous));
800 }
801 return at::add_out(result, self, other, alpha); // redispatch!
802 }
803
804 Tensor& add_sparse_csr_(
805 Tensor& self,
806 const Tensor& other,
807 const Scalar& alpha) {
808 return at::add_out(self, self, other, alpha); // redispatch!
809 }
810
811 static void add_out_dense_sparse_compressed_cpu(
812 const Tensor& out,
813 const Tensor& dense,
814 const SparseCsrTensor& src,
815 const Scalar& alpha) {
816 TORCH_INTERNAL_ASSERT(dense.layout() == kStrided);
817 TORCH_INTERNAL_ASSERT(
818 src.layout() == kSparseCsr || src.layout() == kSparseCsc);
819 TORCH_INTERNAL_ASSERT(dense.device() == kCPU || dense.device() == kMeta);
820
821 TORCH_CHECK(
822 out.is_contiguous(),
823 "out argument must be contiguous, but got: ",
824 out.suggest_memory_format());
825 TORCH_CHECK(
826 out.device() == dense.device(),
827 "add: expected 'out' to match dense tensor, but got tensor on device: ",
828 out.device());
829 TORCH_CHECK(
830 src.device() == dense.device(),
831 "add: expected 'src' to match dense tensor, but got tensor on device: ",
832 src.device());
833
834 TORCH_CHECK(
835 dense.sizes().equals(src.sizes()),
836 "add: expected 'self' and 'other' to have same size, but self has size ",
837 dense.sizes(),
838 " while other has size ",
839 src.sizes(),
840 " (FYI: op2-sparse addition does not currently support broadcasting)");
841
842 auto commonDtype = promoteTypes(dense.scalar_type(), src.scalar_type());
843 TORCH_CHECK(
844 canCast(commonDtype, out.scalar_type()),
845 "Can't convert result type ",
846 commonDtype,
847 " to output ",
848 out.scalar_type(),
849 " in add operation");
850
851 auto src_values = src.values();
852
853 resize_output(out, dense.sizes());
854
855 Tensor resultBuffer = out;
856
857 if (out.scalar_type() != commonDtype) {
858 resultBuffer = dense.to(commonDtype);
859 } else if (!is_same_tensor(out, dense)) {
860 resultBuffer.copy_(dense);
861 }
862
863 if (src._nnz() == 0) {
864 return;
865 }
866
867 TORCH_INTERNAL_ASSERT(dense.device() == kCPU);
868
869 auto valuesBuffer = src_values.to(commonDtype).reshape({-1, src_values.size(-1)});
870 resultBuffer = resultBuffer.view({-1, out.size(-2), out.size(-1)});
871 Tensor src_compressed_indices;
872 Tensor src_plain_indices;
873 std::tie(src_compressed_indices, src_plain_indices) =
874 at::sparse_csr::getCompressedPlainIndices(src);
875 src_compressed_indices =
876 src_compressed_indices.reshape({-1, src_compressed_indices.size(-1)});
877 src_plain_indices =
878 src_plain_indices.reshape({-1, src_plain_indices.size(-1)});
879 auto src_layout = src.layout();
880
881 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
882 kComplexHalf,
883 kHalf,
884 kBool,
885 kBFloat16,
886 commonDtype,
887 "add_out_op2_sparse_csr",
888 [&valuesBuffer,
889 &resultBuffer,
890 &alpha,
891 &src_compressed_indices,
892 &src_plain_indices,
893 &src_layout]() {
894 AT_DISPATCH_INDEX_TYPES(
895 src_compressed_indices.scalar_type(),
896 "csr_add_out_crow_indices",
897 [&valuesBuffer,
898 &resultBuffer,
899 &alpha,
900 &src_compressed_indices,
901 &src_plain_indices,
902 &src_layout]() {
903 auto batch_count =
904 resultBuffer.dim() > 2 ? resultBuffer.size(-3) : 1;
905 auto values_accessor = valuesBuffer.accessor<scalar_t, 2>();
906 scalar_t* out_ptr = resultBuffer.data_ptr<scalar_t>();
907 scalar_t cast_value = alpha.to<scalar_t>();
908
909 auto compressed_indices_accessor =
910 src_compressed_indices.accessor<index_t, 2>();
911 auto plain_indices_accessor =
912 src_plain_indices.accessor<index_t, 2>();
913 auto out_strides = resultBuffer.strides();
914 auto const out_stride_batch = out_strides[0];
915 auto const out_stride_compressed =
916 AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
917 src_layout,
918 "add_out_dense_sparse_compressed_cpu",
919 [&out_strides] { return out_strides[1]; },
920 [&out_strides] { return out_strides[2]; });
921 auto const out_stride_plain =
922 AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS(
923 src_layout,
924 "add_out_dense_sparse_compressed_cpu",
925 [&out_strides] { return out_strides[2]; },
926 [&out_strides] { return out_strides[1]; });
927
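// `index` below addresses out[batch_idx][row][col] where (row, col) equals
// (i_compressed, i_plain) for CSR inputs and (i_plain, i_compressed) for CSC
// inputs, matching the stride selection above.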
928 for (const auto batch_idx : c10::irange(batch_count)) {
929 for (const auto i_compressed :
930 c10::irange(src_compressed_indices.size(-1) - 1)) {
931 index_t start_index =
932 compressed_indices_accessor[batch_idx][i_compressed];
933 index_t end_index =
934 compressed_indices_accessor[batch_idx][i_compressed + 1];
935 for (const auto i : c10::irange(start_index, end_index)) {
936 auto i_plain = plain_indices_accessor[batch_idx][i];
937 auto index = batch_idx * out_stride_batch +
938 i_compressed * out_stride_compressed +
939 i_plain * out_stride_plain;
940 out_ptr[index] +=
941 cast_value * values_accessor[batch_idx][i];
942 }
943 }
944 }
945 });
946 });
947 if (out.scalar_type() != commonDtype) {
948 out.copy_(resultBuffer);
949 }
950 }
951
952 Tensor& add_out_sparse_compressed_cpu(
953 const Tensor& self,
954 const SparseCsrTensor& other,
955 const Scalar& alpha,
956 SparseCsrTensor& out) {
957 if (self.layout() == kStrided) {
958 add_out_dense_sparse_compressed_cpu(out, self, other, alpha);
959 } else if (other.layout() == kStrided) {
960 add_out_dense_sparse_compressed_cpu(out, other, self, alpha);
961 } else {
962 TORCH_CHECK(
963 self.sizes().equals(other.sizes()),
964 "torch.add: Expected input tensors to have the same shape, but got tensor `self` with shape ",
965 self.sizes(),
966 " and tensor `other` with shape ",
967 other.sizes());
968
969 if (only_sparse_compressed_add_trivial_cases(self, other, alpha, out)) {
970 return out;
971 }
972
973 at::native::resize_as_sparse_compressed_(out, self);
974 sparse::impl::cpu::add_out_sparse_csr(self, other, alpha, out);
975 }
976 return out;
977 }
978
979 /*
980 Reductions on sparse CSR tensors using masked semantics.
981
982 - A CSR tensor is a 2D tensor that is specified by a 3-tuple
983 (crow_indices, col_indices, values).
984
985 - To support a reduction operator on a CSR tensor, define:
986
987 template <typename scalar_t>
988 struct Reduction...Op {
989 inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
990 return a ... b;
991 }
992 inline scalar_t identity() const { return ...; }
993 };
994
995 Tensor _sparse_csr_..._cpu(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
996 ...
997 result = reduce_sparse_csr_cpu_template<scalar_t>(input_, dims_to_sum, keepdim, Reduction...Op<scalar_t>());
998 ...
999 return result;
1000 }
1001
1002 and add the following
1003
1004 - func: _sparse_csr_op.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
1005 dispatch:
1006 SparseCsrCPU: _sparse_csr_..._cpu
1007
1008 to native_functions.yaml
1009
1010 Use ReductionAddOp and _sparse_csr_sum implementation as an example.
1011
1012 - Since a CSR tensor dimensionality is always 2, only reductions
1013 with keepdim=True can be supported.
1014
1015 */
1016
1017 namespace {
1018
1019 template <typename scalar_t, typename ReductionOp>
1020 Tensor reduce_sparse_csr_dim0_cpu_template(const Tensor& sparse, ReductionOp rop) {
1021 /*
1022 Consider the following sparse tensor:
1023
1024 1 * * * *
1025 * * * 2 *
1026 * * 3 * *
1027 * * * * *
1028 4 * 5 * *
1029
1030 that has CSR representation
1031
1032 crow_indices = [0, 1, 2, 3, 3, 5]
1033 col_indices = [0, 3, 2, 0, 2]
1034 values = [1, 2, 3, 4, 5]
1035
1036 Reduction with dim=0 results:
1037
1038 rop(1,4) * rop(3,5) 2 *
1039
1040 that has CSR representation
1041
1042 new_crow_indices = [0, 3]
1043 new_col_indices = [0, 2, 3]
1044 new_values = [rop(1, 4), rop(3, 5), 2]
1045
1046 In general, the CSR representation data can be computed as follows:
1047
1048 new_col_indices, col_map = col_indices.unique(sorted=True, return_inverse=True)
1049 nnz = new_col_indices.numel()
1050 new_crow_indices = [0, nnz]
1051 new_values.resize(nnz); new_values.fill_(identity)
1052 for i in range(col_indices.numel()):
1053 new_values[col_map[i]] = rop(new_values[col_map[i]], values[i])
1054 */
1055
1056 Tensor col_indices = sparse.col_indices();
1057 Tensor values = sparse.values();
1058 auto numel = values.numel();
1059
1060 /*
1061 Calling at::_unique constitutes the main bottleneck of this
1062 function. However, it is still about 5x faster than using the
1063 invariant:
1064 csr.sum(dim=0) == csr.transpose(0, 1).sum(dim=1)
1065 */
1066 auto [new_col_indices, columns_map] = at::_unique(col_indices, true, true);
1067 auto nnz = new_col_indices.numel();
1068
1069 Tensor new_crow_indices = at::empty({2}, col_indices.options());
1070 new_crow_indices[0] = 0;
1071 new_crow_indices[1] = nnz;
1072
1073 // Pass `is_cuda` = `true` to acc_type even though this is the CPU backend,
1074 // because here float values should accumulate into float: acc_type maps
1075 // float to float when is_cuda is true, but to double for the CPU.
1076 using acc_t = at::acc_type<scalar_t, true>;
1077 auto acc_buffer = at::sparse_csr::create_acc_buffer<acc_t, scalar_t>(
1078 values.options(), values.scalar_type(), nnz);
1079 Tensor new_values = std::get<0>(acc_buffer);
1080 Tensor new_values_acc = std::get<1>(acc_buffer);
1081 new_values_acc.fill_(rop.identity());
1082
1083 int64_t* columns_map_ptr = columns_map.data_ptr<int64_t>();
1084 scalar_t* values_ptr = values.data_ptr<scalar_t>();
1085 acc_t* new_values_acc_ptr =
1086 new_values_acc.data_ptr<acc_t>();
1087
1088 // There is no point in parallelizing the following for-loop
1089 // because about 99.3% of the computation time is spent in the
1090 // at::_unique call above.
1091 for (const auto i : c10::irange(numel)) {
1092 int64_t col = columns_map_ptr[i];
1093 scalar_t val = values_ptr[i];
1094 new_values_acc_ptr[col] = rop(new_values_acc_ptr[col], static_cast<acc_t>(val));
1095 }
1096 copy_from_acc_buffer(new_values, new_values_acc);
1097
1098 return at::native::_sparse_csr_tensor_unsafe(new_crow_indices, new_col_indices, new_values,
1099 {1, sparse.size(1)},
1100 new_values.scalar_type(),
1101 sparse.layout(),
1102 new_values.device());
1103 }
1104
1105 template <typename scalar_t, typename ReductionOp>
1106 Tensor reduce_sparse_csr_dim1_cpu_template(const Tensor& sparse, ReductionOp rop) {
1107 /*
1108 Consider the following sparse tensor:
1109
1110 1 * * * *
1111 * * * 2 *
1112 * * 3 * *
1113 * * * * *
1114 4 * 5 * *
1115
1116 that has CSR representation
1117
1118 crow_indices = [0, 1, 2, 3, 3, 5]
1119 col_indices = [0, 3, 2, 0, 2]
1120 values = [1, 2, 3, 4, 5]
1121
1122 Reduction with dim=1 results:
1123
1124 1
1125 2
1126 3
1127 *
1128 rop(4, 5)
1129
1130 that has CSR representation
1131
1132 new_crow_indices = [0, 1, 2, 3, 3, 4]
1133 new_col_indices = [0, 0, 0, 0]
1134 new_values = [1, 2, 3, rop(4, 5)]
1135
1136 In general, the result CSR data can be computed as follows:
1137
1138 new_crow_indices = [0]
1139 for i in range(1, nrows+1):
1140 new_crow_indices[i] = new_crow_indices[i-1] + (crow_indices[i] != crow_indices[i-1])
1141 nnz = new_crow_indices[-1]
1142 new_col_indices = zeros(nnz)
1143 new_values.resize(nnz)
1144 j = -1
1145 for i in range(1, nrows+1):
1146 if crow_indices[i] == crow_indices[i-1]:
1147 continue
1148 j += 1
1149 new_values[j] = rop(values[crow_indices[i-1] : crow_indices[i]])
1150 */
1151
1152 Tensor crow_indices = sparse.crow_indices();
1153 auto ioptions = crow_indices.options();
1154 Tensor values = sparse.values();
1155 auto nrows = sparse.size(0);
1156
1157 Tensor new_crow_indices = at::empty({crow_indices.numel()}, ioptions);
1158 Tensor new_col_indices = at::empty({}, ioptions);
1159 Tensor row_map = at::empty({nrows}, ioptions);
1160
1161 // Pass `is_cuda` = `true` to acc_type even though this is the CPU backend,
1162 // because here float values should accumulate into float: acc_type maps
1163 // float to float when is_cuda is true, but to double for the CPU.
1164 using acc_t = at::acc_type<scalar_t, true>;
1165 auto acc_buffer = at::sparse_csr::create_acc_buffer<acc_t, scalar_t>(
1166 values.options(), values.scalar_type());
1167 Tensor new_values = std::get<0>(acc_buffer);
1168 Tensor new_values_acc = std::get<1>(acc_buffer);
1169
1170 AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "reduce_sparse_csr_dim1_cpu_indices",
1171 [&]() {
1172 index_t* crow_indices_ptr = crow_indices.data_ptr<index_t>();
1173 index_t* new_crow_indices_ptr = new_crow_indices.data_ptr<index_t>();
1174 index_t* row_map_ptr = row_map.data_ptr<index_t>();
1175 int64_t nnz = 0;
1176 new_crow_indices_ptr[0] = 0;
1177 for(int64_t i=0; i<nrows; i++) {
1178 if (crow_indices_ptr[i] != crow_indices_ptr[i + 1]) {
1179 row_map_ptr[i] = nnz;
1180 nnz++;
1181 }
1182 new_crow_indices_ptr[i + 1] = nnz;
1183 }
1184 new_col_indices.resize_(nnz);
1185 new_col_indices.fill_(index_t(0));
1186 new_values.resize_(nnz);
1187 new_values_acc.resize_(nnz);
1188
1189 scalar_t* values_ptr = values.data_ptr<scalar_t>();
1190 acc_t* new_values_acc_ptr = new_values_acc.data_ptr<acc_t>();
1191
1192 at::parallel_for(
1193 0,
1194 nrows,
1195 internal::GRAIN_SIZE,
1196 [&](int64_t irow_start, int64_t irow_end) {
1197 index_t i_end = crow_indices_ptr[irow_start];
1198 for (index_t h = irow_start; h < irow_end; ++h) {
1199 index_t i_start = i_end;
1200 i_end = crow_indices_ptr[h+1];
1201 if (i_start != i_end) {
1202 acc_t res = static_cast<acc_t>(values_ptr[i_start]);
1203 for (index_t i = i_start + 1; i < i_end; i++) {
1204 res = rop(res, static_cast<acc_t>(values_ptr[i]));
1205 }
1206 new_values_acc_ptr[row_map_ptr[h]] = res;
1207 }
1208 }
1209 });
1210 });
1211
1212 copy_from_acc_buffer(new_values, new_values_acc);
1213
1214 return at::native::_sparse_csr_tensor_unsafe(new_crow_indices, new_col_indices, new_values,
1215 {sparse.size(0), 1},
1216 new_values.scalar_type(),
1217 sparse.layout(),
1218 new_values.device());
1219 }
1220
1221 template <typename scalar_t, typename ReductionOp>
1222 Tensor reduce_sparse_csr_dim01_cpu_template(const Tensor& sparse, ReductionOp rop) {
1223
1224 auto ioptions = sparse.col_indices().options();
1225 Tensor values = sparse.values();
1226 auto numel = values.numel();
1227 auto nnz = std::min<int64_t>(1, numel);
1228
1229 /* TODO: we can likely do about 3x better than parallel_reduce:
1230
1231 In [2]: t=torch.randn(5000, 5000).to_sparse_csr()
1232
1233 In [3]: %timeit torch._sparse_csr_sum(t, dim=(0, 1), keepdim=True)
1234 3.39 ms ± 898 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
1235
1236 In [4]: %timeit torch.sum(t.values())
1237 1.07 ms ± 291 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
1238 */
1239
1240 // Pass `is_cuda` = `true` to acc_type even though this is the CPU backend,
1241 // because here float values should accumulate into float: acc_type maps
1242 // float to float when is_cuda is true, but to double for the CPU.
1243 using acc_t = at::acc_type<scalar_t, true>;
1244 scalar_t* values_ptr = values.data_ptr<scalar_t>();
1245 acc_t value = at::parallel_reduce(
1246 0,
1247 numel,
1248 internal::GRAIN_SIZE,
1249 rop.identity(),
1250 [&](int64_t i_start, int64_t i_end, scalar_t identity) {
1251 acc_t res = acc_t(identity);
1252 for (int64_t i=i_start; i<i_end; i++) {
1253 acc_t val = acc_t(values_ptr[i]);
1254 res = rop(res, val);
1255 }
1256 return res;
1257 }, rop
1258 );
1259
1260 Tensor new_col_indices = at::zeros({nnz}, ioptions);
1261 Tensor new_crow_indices = at::tensor(ArrayRef<int64_t>{0, nnz}, ioptions);
1262 Tensor new_values;
1263 auto result_dtype = at::isIntegralType(values.scalar_type(), /*includeBool=*/true) ? ScalarType::Long : values.scalar_type();
1264 if (numel > 0) {
1265 new_values = at::empty({1}, values.options().dtype(result_dtype));
1266 new_values.fill_(value);
1267 } else {
1268 new_values = at::empty({}, values.options().dtype(result_dtype));
1269 }
1270 return at::native::_sparse_csr_tensor_unsafe(new_crow_indices, new_col_indices, new_values,
1271 {1, std::min<int64_t>(1, sparse.size(1))},
1272 new_values.scalar_type(),
1273 sparse.layout(),
1274 new_values.device());
1275 }
1276
1277 template <typename scalar_t, typename ReductionOp>
1278 Tensor reduce_sparse_csr_cpu_template(const Tensor& sparse, std::vector<int64_t> dims, ReductionOp rop) {
1279 if (dims.size() == 1) {
1280 if (dims[0] == 0) {
1281 return reduce_sparse_csr_dim0_cpu_template<scalar_t>(sparse, rop);
1282 } else {
1283 TORCH_INTERNAL_ASSERT(dims[0] == 1);
1284 return reduce_sparse_csr_dim1_cpu_template<scalar_t>(sparse, rop);
1285 }
1286 } else if (dims.size() == 2) {
1287 TORCH_INTERNAL_ASSERT(((dims[0] == 0 && dims[1] == 1) || (dims[0] == 1 && dims[1] == 0)));
1288 return reduce_sparse_csr_dim01_cpu_template<scalar_t>(sparse, rop);
1289 }
1290 TORCH_INTERNAL_ASSERT(dims.empty());
1291 // effective after gh-29137 has been resolved
1292 return sparse.clone();
1293 }
1294
1295 template <typename scalar_t, typename ReductionOp>
1296 Tensor reduce_sparse_csr_cpu_template(const Tensor& sparse, IntArrayRef dims_to_sum, bool keepdim, ReductionOp rop) {
1297 TORCH_INTERNAL_ASSERT(sparse.is_sparse_csr());
1298 TORCH_CHECK(keepdim, "reduction operations on CSR tensors with keepdim=False are not supported");
1299 TORCH_INTERNAL_ASSERT(sparse.device() == kCPU);
1300
1301 const int64_t input_dim = sparse.dim();
1302 TORCH_INTERNAL_ASSERT(input_dim == 2);
1303 auto dims = dims_to_sum.vec();
1304 maybe_wrap_dims(dims, input_dim);
1305 if (dims.empty()) {
1306 // after gh-29137 is resolved, delete this if-block
1307 dims.emplace_back(0);
1308 dims.emplace_back(1);
1309 }
1310 return reduce_sparse_csr_cpu_template<scalar_t>(sparse, dims, rop);
1311 }
1312
1313 template <typename scalar_t>
1314 struct ReductionAddOp {
1315 inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
1316 return a + b;
1317 }
1318 inline scalar_t identity() const { return 0; }
1319 };
1320
1321 template <typename scalar_t>
1322 struct ReductionMulOp {
1323 inline scalar_t operator()(const scalar_t& a, const scalar_t& b) const {
1324 return a * b;
1325 }
1326 inline scalar_t identity() const { return 1; }
1327 };
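// Usage sketch (mirroring _sparse_csr_sum_cpu / _sparse_csr_prod_cpu below):
// the functors are passed by value into the reduction templates, e.g.
//   reduce_sparse_csr_cpu_template<scalar_t>(
//       input, dims_to_sum, keepdim, ReductionAddOp<acc_t>());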
1328
1329 } // namespace
1330
1331 Tensor _sparse_csr_sum_cpu(const Tensor& input, IntArrayRef dims_to_sum, bool keepdim, std::optional<ScalarType> dtype) {
1332 ScalarType dtype_ = dtype.value_or(input.scalar_type());
1333 Tensor input_ = at::sparse_csr::to_type(input, dtype_);
1334 Tensor result;
1335 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
1336 kHalf, kBFloat16, input_.scalar_type(), "_sparse_csr_sum_cpu", [&] {
1337 // Pass `is_cuda` = `true` to acc_type even though this is the CPU backend,
1338 // because here float values should accumulate into float: acc_type maps
1339 // float to float when is_cuda is true, but to double for the CPU.
1340 using acc_t = at::acc_type<scalar_t, true>;
1341 result = reduce_sparse_csr_cpu_template<scalar_t>(
1342 input_, dims_to_sum, keepdim, ReductionAddOp<acc_t>());
1343 });
1344 return result;
1345 }
1346
1347 Tensor _sparse_csr_prod_cpu(const Tensor& input, IntArrayRef dims_to_reduce, bool keepdim, std::optional<ScalarType> dtype) {
1348 ScalarType dtype_ = dtype.value_or(input.scalar_type());
1349 Tensor input_ = input.to(dtype_);
1350 Tensor result;
1351 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(
1352 kHalf, kBFloat16, input_.scalar_type(), "_sparse_csr_prod_cpu",
1353 [&] {
1354 result = reduce_sparse_csr_cpu_template<scalar_t>(input_, dims_to_reduce, keepdim, ReductionMulOp<scalar_t>());
1355 });
1356 return result;
1357 }
1358
1359 std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_sparse_csr_cpu(
1360 const Tensor& self,
1361 const Tensor& other,
1362 const c10::string_view reduce) {
1363
1364 auto layout = self.layout();
1365 TORCH_CHECK(layout == kSparseCsr,
1366 "sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
1367 TORCH_CHECK(self.dense_dim() == 0,
1368 "sparse_mm_reduce: expected non-hybrid self tensor.");
1369 TORCH_CHECK(self.dim() == 2,
1370 "sparse_mm_reduce: expected self to be a 2-D tensor, got ", self.dim(), "-D tensor.");
1371
1372 sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/false>(
1373 self, Tensor(), other);
1374
1375 auto op = get_reduction_enum(reduce);
1376 TORCH_CHECK(op != ReductionType::PROD, "sparse_mm_reduce: reduce type of prod has not been enabled.");
1377
1378 auto crow = self.crow_indices();
1379 auto col = self.col_indices();
1380 auto val = self.values();
1381
1382 // Initialize the output to all zeros; rows of `self` with no nonzero
1383 // elements produce zero rows in the output.
1384 auto out = at::zeros({self.size(0), other.size(1)}, other.options());
1385 auto arg_out = at::empty({0}, col.options());
1386
1387 int64_t nnz = self._nnz();
1388 if (nnz == 0) {
1389 return std::make_tuple(out, arg_out);
1390 }
1391
1392 // arg_out only needs to be computed for the reduce types
1393 // "amax" and "amin", and only during training
1394 bool need_arg_out = at::GradMode::is_enabled()
1395 && (self.requires_grad() || other.requires_grad())
1396 && (op == ReductionType::MAX || op == ReductionType::MIN);
1397
1398 if (!need_arg_out) {
1399 spmm_reduce_stub(kCPU, out, crow, col, val, other, op);
1400 } else {
1401 // allocate memory and init with invalid index
1402 arg_out.resize_(out.sizes());
1403 arg_out.fill_(nnz);
1404 spmm_reduce_arg_stub(kCPU, out, arg_out, crow, col, val, other, op);
1405 }
1406
1407 return std::make_tuple(std::move(out), std::move(arg_out));
1408 }
1409
1410 std::tuple<Tensor, Tensor> _sparse_mm_reduce_impl_backward_sparse_csr_cpu(
1411 const Tensor& self,
1412 const Tensor& grad_out,
1413 const Tensor& other,
1414 const c10::string_view reduce,
1415 const Tensor& arg_out,
1416 std::array<bool, 2> output_mask) {
1417
1418 auto layout = self.layout();
1419 TORCH_CHECK(layout == kSparseCsr,
1420 "sparse_mm_reduce: expect self to be SparseCsr, got ", layout);
1421
1422 sparse::impl::check_sparse_mm_reduce_impl_inputs</*train*/true>(
1423 self, grad_out, other);
1424
1425 auto op = get_reduction_enum(reduce);
1426
1427 auto crow = self.crow_indices();
1428 auto col = self.col_indices();
1429 auto val = self.values();
1430
1431 // `row`: row indices of COO format
1432 // `ccol`: ccol indices of CSC format (with permute)
1433 // `permute`: permute pattern from CSR to CSC
1434 //
1435 // TODO: optimize the following section,
1436 // currently `argsort` is sequential.
1437 Tensor row, ccol, permute;
1438 {
1439 bool out_int32 = crow.scalar_type() == ScalarType::Int;
1440 Tensor coo_indices = at::_convert_indices_from_csr_to_coo(
1441 crow,
1442 col,
1443 out_int32,
1444 /*transpose*/false);
1445 row = coo_indices.select(0, 0);
1446
1447 // calculate the global index for CSC
1448 // and get the conversion permute pattern
1449 Tensor index = col.mul(self.size(0)).add_(row);
1450 permute = index.argsort();
1451
1452 ccol = at::_convert_indices_from_coo_to_csr(
1453 /*column indices*/col.index_select(0, permute),
1454 /*column count*/self.size(1),
1455 out_int32);
1456 }
1457
1458 Tensor grad_self, grad_other;
1459 if (output_mask[0]) {
1460 // grad_input has the same indices and nnz as the input
1461 grad_self = at::empty_like(self);
1462 grad_self.values().zero_();
1463 if (op == ReductionType::MAX || op == ReductionType::MIN) {
1464 spmm_reduce_backward_input_arg_stub(kCPU, grad_self, grad_out, col, other, arg_out, op);
1465 } else {
1466 spmm_reduce_backward_input_stub(kCPU, grad_self, grad_out, crow, col, other, row, op);
1467 }
1468 }
1469 if (output_mask[1]) {
1470 grad_other = at::zeros(other.sizes(), other.options());
1471 if (op == ReductionType::MAX || op == ReductionType::MIN) {
1472 spmm_reduce_backward_other_arg_stub(kCPU, grad_other, grad_out, col, val, arg_out, op);
1473 } else {
1474 spmm_reduce_backward_other_stub(kCPU, grad_other, grad_out, crow, val, row, ccol, permute, op);
1475 }
1476 }
1477
1478 return std::make_tuple(std::move(grad_self), std::move(grad_other));
1479 }
1480
1481 DEFINE_DISPATCH(spmm_reduce_stub);
1482 DEFINE_DISPATCH(spmm_reduce_arg_stub);
1483 DEFINE_DISPATCH(spmm_reduce_backward_input_stub);
1484 DEFINE_DISPATCH(spmm_reduce_backward_input_arg_stub);
1485 DEFINE_DISPATCH(spmm_reduce_backward_other_stub);
1486 DEFINE_DISPATCH(spmm_reduce_backward_other_arg_stub);
1487
1488 } // namespace native
1489 } // namespace at
1490