#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/sparse/ParamUtils.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_log_softmax_backward_data_cpu_dispatch.h>
#include <ATen/ops/_log_softmax_cpu_dispatch.h>
#include <ATen/ops/_softmax_backward_data_cpu_dispatch.h>
#include <ATen/ops/_softmax_cpu_dispatch.h>
#include <ATen/ops/_sparse_log_softmax.h>
#include <ATen/ops/_sparse_log_softmax_backward_data_native.h>
#include <ATen/ops/_sparse_log_softmax_native.h>
#include <ATen/ops/_sparse_softmax.h>
#include <ATen/ops/_sparse_softmax_backward_data_native.h>
#include <ATen/ops/_sparse_softmax_native.h>
#endif

#include <map>

namespace at::native {
namespace {

int64_t get_nvalues(const IntArrayRef& sizes, int64_t sparse_dim) {
  /* Return the number of entries in the dense part of a sparse tensor.

     `sizes` is a vector of sparse tensor dimensions.
     `sparse_dim` is the dimension of the sparse part of a sparse tensor.
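
     For example (an illustrative case, not drawn from this file's tests):
     for sizes == (2, 5, 3, 4) and sparse_dim == 2, the dense part has
     shape (3, 4), so get_nvalues returns 3 * 4 == 12.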
   */
  return c10::multiply_integers(sizes.begin() + sparse_dim, sizes.end());
}

std::vector<int64_t> get_offsets(const Tensor& indices, const IntArrayRef& sizes, const int64_t dim) {
  /*
    Given the indices of a sparse tensor, return a vector of offsets
    for the entries in the equivalent dense tensor:

    If
      offsets = get_offsets(A._indices(), A.sizes(), -1)
      data = A.to_dense().reshape(-1)
    then
      data[offsets[n]] == A._values()[n]

    `indices` must be a contiguous 2-d tensor with int64_t entries.
    `sizes` must be a vector with at least ndim entries.

    `dim` is an integer. When >= 0 and < ndim, the indices of all
    entries in the given dimension will be mapped to the index of the
    first entry before computing the offset. Otherwise, the value is
    ignored.

    For example, consider a sparse tensor

      11 ** ** 14 15
      ** 22 ** 24 **

    with

      indices = [[0, 0, 0, 1, 1],
                 [0, 3, 4, 1, 3]]

    then

      get_offsets(indices, (2, 5), -1) -> [0, 3, 4, 6, 8]
      get_offsets(indices, (2, 5), 0) -> [0, 3, 4, 1, 3]
      get_offsets(indices, (2, 5), 1) -> [0, 0, 0, 5, 5]

  */
  auto ndim = indices.size(0);
  auto nnz = indices.size(1);
  std::vector<int64_t> offsets(nnz);
  std::vector<int64_t> strides(ndim, 1);
  auto indices_accessor = indices.accessor<int64_t, 2>();

  if (ndim > 1) {
    for (int64_t i = ndim - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * sizes[i + 1];
    }
  }

  for (const auto i : c10::irange(nnz)) {
    int64_t acc = 0;
    for (const auto j : c10::irange(ndim)) {
      auto indices_row = indices_accessor[j];
      auto stride = strides[j];
      if (j != dim) {
        acc += stride * indices_row[i];
      }
    }
    offsets[i] = acc;
  }

  return offsets;
}

std::vector<std::vector<int64_t>> get_pools(const Tensor& indices, const IntArrayRef& sizes, const int64_t dim) {
  /*
    Return pools of indices that align with the given dimension.

    Parameters:
      `indices` - sparse tensor indices
      `sizes` - sparse tensor dimensions
      `dim` - given dimension

    Returns:
      `pools` - a ragged array of indices

    A pool is defined as a list of indices (of sparse tensor values)
    that participate in the same softmax computation:

    - pools[i] and pools[j] are disjoint if i != j
    - the union of all pools is set(range(nnz))
    - X.values[k], k in pools[i], does not affect the result of
      softmax(X)[n], n in pools[j], if i != j

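    For illustration (reusing the example indices from get_offsets above;
    this particular case is not exercised in this file):

      indices = [[0, 0, 0, 1, 1],
                 [0, 3, 4, 1, 3]]

      get_pools(indices, (2, 5), 0) -> [[0], [3], [], [1, 4], [2]]
      get_pools(indices, (2, 5), 1) -> [[0, 1, 2], [3, 4]]

    Note that the returned ragged array may contain empty pools (pool
    index 2 in the dim=0 case above); callers skip them.
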
   */
  std::vector<std::vector<int64_t>> pools;

  auto ndim = indices.size(0);
  auto nnz = indices.size(1);
  std::vector<int64_t> strides(ndim, 1);
  auto indices_accessor = indices.accessor<int64_t, 2>();

  if (ndim > 1) {
    for (int64_t i = ndim - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * (i + 1 == dim ? 1 : sizes[i + 1]);
    }
  }

  for (const auto i : c10::irange(nnz)) {
    int64_t pool_index = 0;
    for (const auto j : c10::irange(ndim)) {
      if (j != dim) {
        const auto indices_row = indices_accessor[j];
        const auto stride = strides[j];
        pool_index += stride * indices_row[i];
      }
    }
    if (static_cast<int64_t>(pools.size()) <= pool_index) {
      pools.resize(pool_index + 1);
    }
    pools.at(pool_index).push_back(i);
  }

  return pools;
}

template <typename scalar_t, bool LogSoftMax>
void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t dim) {
  /*
    See test/test_sparse.py:test_softmax:sparse_softmax for the Python
    prototype of the sparse softmax algorithm that this implementation
    is based on.

    Derivation of the sparse softmax algorithm with an example
    ----------------------------------------------------------

    Consider the following 2-D sparse tensor with 0-D dense part as an
    example, denote it by X:

      11 ** ** 14 15
      ** 22 ** 24 **

    where `**` represent unspecified entries. The COO sparse tensor
    representation of X is:

      indices = [[0, 1, 0, 1, 0],
                 [0, 1, 3, 3, 4]]
      values = [11, 22, 14, 24, 15]

    that after coalescing becomes

      indices = [[0, 0, 0, 1, 1],
                 [0, 3, 4, 1, 3]]
      values = [11, 14, 15, 22, 24]

    The softmax of X along the given dimension d is defined as

      S_d[i, j] = exp(X[i, j]) / sum(exp(X[I_d[k]]), k=0..X.shape[d]-1)

    where the index tuple I_d[k] is defined as

      I_0[k] = k, j
      I_1[k] = i, k

    For sparse tensors, the unspecified entries are skipped in the
    softmax sum of exponents so that the result will be a sparse
    tensor with the same indices as the input. Mathematically, this
    corresponds to the case where the unspecified entries are
    interpreted as negative infinities rather than zeros.

    To minimize errors from the numerical evaluation of exponents
    with very large or small arguments, the softmax implementation
    uses the following numerically stable definition:

      S_d[i, j] = exp(X[i, j] - maxX_d) / sum(exp(X[I_d[k]] - maxX_d), k=0...X.shape[d]-1)

    where

      maxX_d = max(X[I_d[k]], k=0...X.shape[d]-1)

    is the maximum tensor along the direction d (it has dimensionality
    `maxX_d.ndim = X.ndim - 1`).

    For the example sparse tensor X, we have:

      S_0._indices() == S_1._indices() == X._indices()

      maxX_0 = [11, 22, -inf, 24, 15]
      maxX_1 = [15, 24]

      S_0._values() = [exp(11 - maxX_0[0]) / exp(11 - maxX_0[0]),
                       exp(14 - maxX_0[3]) / (exp(14 - maxX_0[3]) + exp(24 - maxX_0[3])),
                       exp(15 - maxX_0[4]) / exp(15 - maxX_0[4]),
                       exp(22 - maxX_0[1]) / exp(22 - maxX_0[1]),
                       exp(24 - maxX_0[3]) / (exp(14 - maxX_0[3]) + exp(24 - maxX_0[3]))]
                    = [1, exp(-10)/(exp(-10) + 1), 1, 1, 1/(exp(-10) + 1)]
    (note that `maxX_0[2] == -inf` is not used when obtaining S_0)

      S_1._values() = [exp(11 - maxX_1[0]) / (exp(11 - maxX_1[0]) + exp(14 - maxX_1[0]) + exp(15 - maxX_1[0])),
                       exp(14 - maxX_1[0]) / (exp(11 - maxX_1[0]) + exp(14 - maxX_1[0]) + exp(15 - maxX_1[0])),
                       exp(15 - maxX_1[0]) / (exp(11 - maxX_1[0]) + exp(14 - maxX_1[0]) + exp(15 - maxX_1[0])),
                       exp(22 - maxX_1[1]) / (exp(22 - maxX_1[1]) + exp(24 - maxX_1[1])),
                       exp(24 - maxX_1[1]) / (exp(22 - maxX_1[1]) + exp(24 - maxX_1[1]))]
                    = [exp(-4) / (exp(-4) + exp(-1) + 1),
                       exp(-1) / (exp(-4) + exp(-1) + 1),
                       1 / (exp(-4) + exp(-1) + 1),
                       exp(-2) / (exp(-2) + 1),
                       1 / (exp(-2) + 1)]

    To obtain the above via the for-loop over
    `nnz(=len(X._values()))`, we introduce the indices mapping `pool`
    as follows:

      indices = X._indices()
      for i in range(nnz):
          for j in range(nnz):
              if indices[d, i] == indices[d, j]:
                  assert pool_d[i] == pool_d[j]
              else:
                  assert pool_d[i] != pool_d[j]

    that is, the entries with value indices i and j are in the same
    pool iff their locations in the grid of tensor indices align with
    the direction along which the softmax is calculated. The `pool`
    mapping maps the X._values() indices to the corresponding pool
    index.

    To save memory and processor resources, we pre-compute the entries
    of the maxX tensor and the sums of exponents as follows:

      mx_d = [max(values[i] for i in range(nnz) if pool_d[i] == k) for k in range(max(pool_d) + 1)]
      exp_sum_d = [sum(exp(values[i] - mx_d[k]) for i in range(nnz) if pool_d[i] == k) for k in range(max(pool_d) + 1)]

    For example, if

      pool_0 = [0, 1, 2, 3, 1]
      pool_1 = [0, 0, 0, 1, 1]

    then

      mx_0 = [11, 24, 15, 22]
      mx_1 = [15, 24]
      exp_sum_0 = [1, (exp(-10) + 1), 1, 1]
      exp_sum_1 = [(exp(-4) + exp(-1) + 1), (exp(-2) + 1)]

    and

      S_0._values() = [exp(11 - mx_0[pool_0[0]]) / exp_sum_0[pool_0[0]],
                       exp(14 - mx_0[pool_0[1]]) / exp_sum_0[pool_0[1]],
                       exp(15 - mx_0[pool_0[2]]) / exp_sum_0[pool_0[2]],
                       exp(22 - mx_0[pool_0[3]]) / exp_sum_0[pool_0[3]],
                       exp(24 - mx_0[pool_0[4]]) / exp_sum_0[pool_0[4]]]

    or in general,

      S_d._values() = [exp(values[i] - mx_d[pool_d[i]]) / exp_sum_d[pool_d[i]] for i in range(nnz)]

    The above algorithm can be easily extended for cases with a
    non-scalar dense part of the sparse tensor where all scalar
    operations become element-wise tensor operations.

    The implementation below adds a few more optimizations: pool
    indices are collected up front so that independent pools can be
    processed concurrently, calls to exp are minimized, and the
    softmax machinery is reused for log_softmax.
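
    For log_softmax, the same pooled quantities suffice (this note
    follows from the definitions above): since

      log(S_d._values()[i]) = values[i] - mx_d[pool_d[i]] - log(exp_sum_d[pool_d[i]])

    the code folds log(exp_sum) into mx_row and subtracts the result
    from the input values directly, instead of exponentiating and
    normalizing.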
   */
  auto sparse_dim = input.sparse_dim();
  auto indices = input._indices().contiguous();
  auto values = input._values().contiguous();
  auto out_values = output._values();
  auto out_indices = output._indices();
  out_values.resize_as_(values);
  out_indices.resize_as_(indices);
  out_indices.copy_(indices);

  if (dim >= sparse_dim) {
    // The softmax dimension lies in the dense part of the tensor, so the
    // sparsity pattern is untouched and the dense (log_)softmax kernel can
    // be applied directly to the values tensor.
    if (LogSoftMax) {
      auto new_values =
          at::cpu::_log_softmax(values, dim - sparse_dim + 1, false);
      out_values.set_(new_values);
    } else {
      auto new_values = at::cpu::_softmax(values, dim - sparse_dim + 1, false);
      out_values.set_(new_values);
    }
    return;
  }

  auto nnz = values.size(0);
  auto sizes = input.sizes();
  auto nvalues = get_nvalues(sizes, sparse_dim);

  /* Prepare accessors */
  auto values_2 = values.view({nnz, nvalues});
  auto values_accessor = values_2.accessor<scalar_t, 2>();

  auto out_values_2 = out_values.view({nnz, nvalues});
  auto out_values_accessor = out_values_2.accessor<scalar_t, 2>();

  /* Compute independent pools of indices */
  auto pools = get_pools(indices, sizes, dim);

  int64_t grain_size = 1;
  parallel_for(0, pools.size(), grain_size, [&](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      auto pool_indices = pools[p];

      // Skip empty pools
      if (pool_indices.empty())
        continue;

      /* Prepare scratch space */
      std::vector<scalar_t> mx_row(nvalues, -std::numeric_limits<scalar_t>::infinity());
      std::vector<scalar_t> exp_sums_row(nvalues, 0);

      /* Compute mx */
      for (int64_t i : pool_indices) {
        auto values_row = values_accessor[i];
        for (const auto j : c10::irange(nvalues)) {
          mx_row[j] = std::max(mx_row[j], values_row[j]);
        }
      }

      /* Apply exp to (v - mx) and sum the results */
      for (int64_t i : pool_indices) {
        auto values_row = values_accessor[i];
        auto out_values_row = out_values_accessor[i];
        for (const auto j : c10::irange(nvalues)) {
          auto v = std::exp(values_row[j] - mx_row[j]);
          if (!LogSoftMax) {
            out_values_row[j] = v;
          }
          exp_sums_row[j] += v;
        }
      }

      /* For log_softmax, fold log(sum of exponents) into mx so that the
         final pass below is a single subtraction; otherwise invert the sum
         once so that normalization becomes a multiplication. */
      for (const auto j : c10::irange(nvalues)) {
        if (LogSoftMax) {
          mx_row[j] += std::log(exp_sums_row[j]);
        } else {
          exp_sums_row[j] = 1.0 / exp_sums_row[j];
        }
      }

      /* Normalize with the sum of exponents */
      for (int64_t i : pool_indices) {
        auto values_row = values_accessor[i];
        auto out_values_row = out_values_accessor[i];
        for (const auto j : c10::irange(nvalues)) {
          if (LogSoftMax) {
            out_values_row[j] = values_row[j] - mx_row[j];
          } else {
            out_values_row[j] *= exp_sums_row[j];
          }
        }
      }
    }
  });
}

template <typename scalar_t, bool LogSoftMax>
void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& grad, const Tensor& output, const int64_t dim, ScalarType input_dtype) {
  /*

    If LogSoftMax == false, then

      gI_i = sum_j d<output_j>/d<input_i> * grad_j = sum_j output_i * (1[i==j] - output_j) * grad_j
           = output_i * (grad_i - sum_j output_j * grad_j)

    else

      gI_i = (1 - exp(output_i)) * grad_i - sum_{j} 1[i!=j] * exp(output_i) * grad_j
           = grad_i - exp(output_i) * sum_j grad_j

    where

      i, j in range(shape[dim])
      x_i = x[..., i_dim, ...]
      output.sparse_dim() == grad.sparse_dim()
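
    In the pooled implementation below, the sums over j run only over the
    specified entries that share a pool with entry i (see get_pools):
    unspecified entries act as -inf inputs, so their softmax output is
    zero and, being absent from the sparse grad, they contribute nothing
    to the sums.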
  */
  auto sparse_dim = output.sparse_dim();
  auto sizes = output.sizes().vec();
  auto grad_indices = grad._indices().contiguous();
  auto grad_values = grad._values().contiguous();
  auto out_indices = output._indices().contiguous();
  auto out_values = output._values().contiguous();
  auto values = grad_input._values();
  auto indices = grad_input._indices();
  auto out_nnz = out_values.size(0);
  auto grad_nnz = grad_values.size(0);

  values.resize_as_(out_values);
  values.zero_();
  indices.resize_as_(out_indices);
  indices.copy_(out_indices);

  auto out_offsets = get_offsets(out_indices, sizes, -1);
  auto grad_offsets = get_offsets(grad_indices, sizes, -1);

  if (dim >= sparse_dim) {
    // The softmax dimension lies in the dense part. If grad and output
    // share the same sparsity pattern, apply the dense backward kernel to
    // the whole values tensor; otherwise match entries one by one via
    // their dense offsets.
    if (out_offsets == grad_offsets) {
      if (LogSoftMax) {
        auto r = at::cpu::_log_softmax_backward_data(
            grad_values, out_values, dim - sparse_dim + 1, input_dtype);
        values.set_(r);
      } else {
        auto r = at::cpu::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, input_dtype);
        values.set_(r);
      }
    } else {
      for (const auto i : c10::irange(out_nnz)) {
        auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
        auto j = low - grad_offsets.begin();
        if (j < grad_nnz && out_offsets[i] == grad_offsets[j]) {
          if (LogSoftMax) {
            auto r = at::cpu::_log_softmax_backward_data(
                grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
            values[i].copy_(r);
          } else {
            auto r = at::cpu::_softmax_backward_data(grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
            values[i].copy_(r);
          }
        }
      }
    }
    return;
  }

  auto nnz = values.size(0);
  auto nvalues = get_nvalues(sizes, sparse_dim);

  auto values_2 = values.view({nnz, nvalues});
  auto values_accessor = values_2.accessor<scalar_t, 2>();

  auto out_values_2 = out_values.view({out_nnz, nvalues});
  auto out_values_accessor = out_values_2.accessor<scalar_t, 2>();

  auto grad_values_2 = grad_values.view({grad_nnz, nvalues});
  auto grad_values_accessor = grad_values_2.accessor<scalar_t, 2>();

  /* Compute independent pools of indices */
  auto pools = get_pools(out_indices, sizes, dim);

  int64_t grain_size = 1;
  parallel_for(0, pools.size(), grain_size, [&](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      auto pool_indices = pools[p];

      // Skip empty pools
      if (pool_indices.empty())
        continue;

      std::vector<scalar_t> tmp_row(nvalues, 0);

      /* Compute tmp = - sum_j output_j * grad_j (for log_softmax, tmp = - sum_j grad_j) */
      for (int64_t i : pool_indices) {
        auto out_values_row = out_values_accessor[i];
        auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
        auto j = low - grad_offsets.begin();

        if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
          auto grad_values_row = grad_values_accessor[j];
          for (const auto k : c10::irange(nvalues)) {
            if (LogSoftMax) {
              tmp_row[k] -= grad_values_row[k];
            } else {
              tmp_row[k] -= out_values_row[k] * grad_values_row[k];
            }
          }
        }
      }

      /* Compute grad_input = output * (grad + tmp) (for log_softmax, grad + exp(output) * tmp) */
      for (int64_t i : pool_indices) {
        auto out_values_row = out_values_accessor[i];
        auto values_row = values_accessor[i];
        auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
        auto j = low - grad_offsets.begin();

        if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
          auto grad_values_row = grad_values_accessor[j];
          for (const auto k : c10::irange(nvalues)) {
            if (LogSoftMax) {
              values_row[k] = grad_values_row[k] + std::exp(out_values_row[k]) * tmp_row[k];
            } else {
              values_row[k] = out_values_row[k] * (grad_values_row[k] + tmp_row[k]);
            }
          }
        } else {
          for (const auto k : c10::irange(nvalues)) {
            if (LogSoftMax) {
              values_row[k] = std::exp(out_values_row[k]) * tmp_row[k];
            } else {
              values_row[k] = out_values_row[k] * tmp_row[k];
            }
          }
        }
      }
    }
  });
}

} // anonymous namespace

Tensor softmax_sparse_cpu(
    const Tensor& input_,
    const int64_t dim_,
    const bool half_to_float) {
  Tensor input, output;
  int64_t dim;
  std::tie(input, output, dim) = softmax_sparse_input_preprocessing(
      input_, dim_, half_to_float, "softmax");
  if (input.numel() == 0) {
    return output;
  }
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "softmax", [&] {
    cpu_sparse_coo_softmax<scalar_t, false>(output, input, dim);
  });
  return output;
}

Tensor log_softmax_sparse_cpu(
    const Tensor& input_,
    const int64_t dim_,
    const bool half_to_float) {
  Tensor input, output;
  int64_t dim;
  std::tie(input, output, dim) = softmax_sparse_input_preprocessing(
      input_, dim_, half_to_float, "log_softmax");
  if (input.numel() == 0) {
    return output;
  }
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax", [&] {
    cpu_sparse_coo_softmax<scalar_t, true>(output, input, dim);
  });
  return output;
}

Tensor softmax_backward_sparse_cpu(
    const Tensor& grad_,
    const Tensor& output_,
    int64_t dim_,
    const Tensor& input_) {
  Tensor grad_input, grad, output;
  int64_t dim;
  std::tie(grad_input, grad, output, dim) =
      softmax_backward_sparse_input_preprocessing(
          grad_, output_, dim_, input_, "softmax_backward");
  if (output.numel() == 0) {
    return grad_input;
  }
  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
    cpu_sparse_coo_softmax_backward<scalar_t, false>(
        grad_input, grad, output, dim_, input_.scalar_type());
  });
  return grad_input;
}

Tensor log_softmax_backward_sparse_cpu(
    const Tensor& grad_,
    const Tensor& output_,
    int64_t dim_,
    const Tensor& input_) {
  Tensor grad_input, grad, output;
  int64_t dim;
  std::tie(grad_input, grad, output, dim) =
      softmax_backward_sparse_input_preprocessing(
          grad_, output_, dim_, input_, "log_softmax_backward");
  if (output.numel() == 0) {
    return grad_input;
  }
  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
    cpu_sparse_coo_softmax_backward<scalar_t, true>(
        grad_input, grad, output, dim_, input_.scalar_type());
  });
  return grad_input;
}

Tensor _sparse_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
      return at::_sparse_softmax(input_, dim_, true);
    } else {
      Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
      return at::_sparse_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

Tensor _sparse_softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::_sparse_softmax(self, dimname_to_position(self, dim), dtype);
}

Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
      return at::_sparse_log_softmax(input_, dim_, true);
    } else {
      Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
      return at::_sparse_log_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

Tensor _sparse_log_softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::_sparse_log_softmax(self, dimname_to_position(self, dim), dtype);
}
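
// A minimal usage sketch for the dtype-taking entry points above (an
// illustration only; `indices`, `values`, and the {2, 5} size are
// hypothetical and not taken from this file or its tests):
//
//   // indices: 2 x nnz int64 tensor, values: 1-D floating-point tensor
//   auto X = at::sparse_coo_tensor(indices, values, {2, 5}).coalesce();
//   auto S0 = at::_sparse_softmax(X, /*dim=*/0, /*dtype=*/std::nullopt);
//   auto L1 = at::_sparse_log_softmax(X, /*dim=*/1, /*dtype=*/std::nullopt);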

} // namespace at::native