#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Config.h>
#include <ATen/Dispatch.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/native/sparse/ParamUtils.h>
#include <ATen/native/SparseTensorUtils.h>
#include <ATen/Parallel.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/CPUFunctions.h>
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_log_softmax_backward_data_cpu_dispatch.h>
#include <ATen/ops/_log_softmax_cpu_dispatch.h>
#include <ATen/ops/_softmax_backward_data_cpu_dispatch.h>
#include <ATen/ops/_softmax_cpu_dispatch.h>
#include <ATen/ops/_sparse_log_softmax.h>
#include <ATen/ops/_sparse_log_softmax_backward_data_native.h>
#include <ATen/ops/_sparse_log_softmax_native.h>
#include <ATen/ops/_sparse_softmax.h>
#include <ATen/ops/_sparse_softmax_backward_data_native.h>
#include <ATen/ops/_sparse_softmax_native.h>
#endif

#include <map>

namespace at::native {
namespace {

int64_t get_nvalues(const IntArrayRef& sizes, int64_t sparse_dim) {
  /* Return the number of entries in the dense part of a sparse tensor.

     `sizes` is a vector of sparse tensor dimensions.
     `sparse_dim` is the number of dimensions in the sparse part of the
     sparse tensor; the remaining dimensions form the dense part.
   */
  return c10::multiply_integers(sizes.begin() + sparse_dim, sizes.end());
}
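
/* Example: for a sparse tensor of shape (2, 5, 3, 4) with sparse_dim == 2,
   the dense part has shape (3, 4), so

     get_nvalues({2, 5, 3, 4}, 2) -> 12

   i.e. each of the nnz entries of `values` holds 12 dense elements. */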

std::vector<int64_t> get_offsets(const Tensor& indices, const IntArrayRef& sizes, const int64_t dim) {
  /*
    Given the indices of a sparse tensor, return a vector of offsets
    for the entries in the equivalent dense tensor:

      If
        offsets = get_offsets(A._indices(), A.sizes(), -1)
        data = A.to_dense().reshape(-1)
      then
        data[offsets[n]] == A._values()[n]

    `indices` must be a contiguous 2-d tensor with int64_t entries.
    `sizes` must be a vector with at least ndim entries.

    `dim` is an integer. When 0 <= dim < ndim, the index along that
    dimension is treated as 0 when computing the offset (that is, all
    entries along `dim` are collapsed onto its first entry). Otherwise,
    `dim` is ignored and the full flattened offsets are returned.

    For example, consider a sparse tensor

      11 ** ** 14 15
      ** 22 ** 24 **

    with

      indices = [[0, 0, 0, 1, 1],
                 [0, 3, 4, 1, 3]]

    then

      get_offsets(indices, (2, 5), -1) -> [0, 3, 4, 6, 8]
      get_offsets(indices, (2, 5), 0) -> [0, 3, 4, 1, 3]
      get_offsets(indices, (2, 5), 1) -> [0, 0, 0, 5, 5]

  */
  auto ndim = indices.size(0);
  auto nnz = indices.size(1);
  std::vector<int64_t> offsets(nnz);
  std::vector<int64_t> strides(ndim, 1);
  auto indices_accessor = indices.accessor<int64_t, 2>();

  if (ndim > 1) {
    for (int64_t i = ndim - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * sizes[i + 1];
    }
  }

  for (const auto i : c10::irange(nnz)) {
    int64_t acc = 0;
    for (const auto j : c10::irange(ndim)) {
      auto indices_row = indices_accessor[j];
      auto stride = strides[j];
      if (j != dim) {
        acc += stride * indices_row[i];
      }
    }
    offsets[i] = acc;
  }

  return offsets;
}
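
/* In the example above the row-major strides for sizes (2, 5) are
   [5, 1], so each offset is 5 * row + 1 * col with the `dim` term
   dropped when 0 <= dim < ndim.  E.g. for the entry at (1, 1):

     dim == -1 -> 5 * 1 + 1 * 1 == 6
     dim ==  0 ->         1 * 1 == 1
     dim ==  1 -> 5 * 1         == 5

   which matches offsets[3] in the three example outputs. */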

std::vector<std::vector<int64_t>> get_pools(const Tensor& indices, const IntArrayRef& sizes, const int64_t dim) {
  /*
    Return pools of indices that align with the given dimension.

    Parameters:
      `indices` - sparse tensor indices
      `sizes`   - sparse tensor dimensions
      `dim`     - given dimension

    Returns:
      `pools`   - a ragged array of indices

    A pool is defined as a list of indices (of sparse tensor values)
    that participate in the same softmax computation:

    - pools[i] intersection with pools[j] is empty iff i != j
    - union of all pools is set(range(nnz))
    - X.values[k], k in pools[i], does not affect the result of softmax(X)[n], n in pools[j], iff i != j

  */
  std::vector<std::vector<int64_t>> pools;

  auto ndim = indices.size(0);
  auto nnz = indices.size(1);
  std::vector<int64_t> strides(ndim, 1);
  auto indices_accessor = indices.accessor<int64_t, 2>();

  if (ndim > 1) {
    for (int64_t i = ndim - 2; i >= 0; i--) {
      strides[i] = strides[i + 1] * (i + 1 == dim ? 1 : sizes[i + 1]);
    }
  }

  for (const auto i : c10::irange(nnz)) {
    int64_t pool_index = 0;
    for (const auto j : c10::irange(ndim)) {
      if (j != dim) {
        const auto indices_row = indices_accessor[j];
        const auto stride = strides[j];
        pool_index += stride * indices_row[i];
      }
    }
    if (static_cast<int64_t>(pools.size()) <= pool_index) {
      pools.resize(pool_index + 1);
    }
    pools.at(pool_index).push_back(i);
  }

  return pools;
}
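
/* Using the example tensor from get_offsets above,

     indices = [[0, 0, 0, 1, 1],
                [0, 3, 4, 1, 3]]

   pooling along dim 0 groups entries by column and pooling along
   dim 1 groups them by row:

     get_pools(indices, (2, 5), 0) -> [[0], [3], [], [1, 4], [2]]
     get_pools(indices, (2, 5), 1) -> [[0, 1, 2], [3, 4]]

   Note that pools may be empty (column 2 holds no entries), which is
   why the callers below skip empty pools. */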

template <typename scalar_t, bool LogSoftMax>
void cpu_sparse_coo_softmax(Tensor output, const Tensor& input, const int64_t dim) {
  /*
    See test/test_sparse.py:test_softmax:sparse_softmax for the Python
    prototype of the sparse softmax algorithm that this implementation
    is based on.

    Derivation of the sparse softmax algorithm with an example
    ----------------------------------------------------------

    Consider the following 2-D sparse tensor with 0-D dense part as an
    example; denote it by X:

      11 ** ** 14 15
      ** 22 ** 24 **

    where `**` represents an unspecified entry. The COO sparse tensor
    representation of X is:

      indices = [[0, 1, 0, 1, 0],
                 [0, 1, 3, 3, 4]]
      values = [11, 22, 14, 24, 15]

    which after coalescing becomes

      indices = [[0, 0, 0, 1, 1],
                 [0, 3, 4, 1, 3]]
      values = [11, 14, 15, 22, 24]

    The softmax of X along the given dimension d is defined as

      S_d[i, j] = exp(X[i, j]) / sum(exp(X[I_d[k]]), k=0..X.shape[d]-1)

    where the index tuple I_d[k] is defined as

      I_0[k] = k, j
      I_1[k] = i, k

    For sparse tensors, the unspecified entries are skipped in the
    softmax sum of exponents so that the result is a sparse tensor
    with the same indices as the input. Mathematically, this
    corresponds to the case where the unspecified entries are
    interpreted as negative infinities rather than zeros.

    To minimize the numerical error from evaluating exponents with
    very large or small arguments, the softmax implementation uses the
    following numerically stable definition:

      S_d[i, j] = exp(X[i, j] - maxX_d) / sum(exp(X[I_d[k]] - maxX_d), k=0..X.shape[d]-1)

    where

      maxX_d = max(X[I_d[k]], k=0..X.shape[d]-1)

    is the maximum tensor along the direction d (it has dimensionality
    `maxX_d.ndim = X.ndim - 1`).

    For the example sparse tensor X, we have:

      S_0._indices() == S_1._indices() == X._indices()

      maxX_0 = [11, 22, -inf, 24, 15]
      maxX_1 = [15, 24]

      S_0._values() = [exp(11 - maxX_0[0]) / exp(11 - maxX_0[0]),
                       exp(14 - maxX_0[3]) / (exp(14 - maxX_0[3]) + exp(24 - maxX_0[3])),
                       exp(15 - maxX_0[4]) / exp(15 - maxX_0[4]),
                       exp(22 - maxX_0[1]) / exp(22 - maxX_0[1]),
                       exp(24 - maxX_0[3]) / (exp(14 - maxX_0[3]) + exp(24 - maxX_0[3]))]
                    = [1, exp(-10)/(exp(-10) + 1), 1, 1, 1/(exp(-10) + 1)]

      (note that `maxX_0[2] == -inf` is not used to obtain S_0)

      S_1._values() = [exp(11 - maxX_1[0]) / (exp(11 - maxX_1[0]) + exp(14 - maxX_1[0]) + exp(15 - maxX_1[0])),
                       exp(14 - maxX_1[0]) / (exp(11 - maxX_1[0]) + exp(14 - maxX_1[0]) + exp(15 - maxX_1[0])),
                       exp(15 - maxX_1[0]) / (exp(11 - maxX_1[0]) + exp(14 - maxX_1[0]) + exp(15 - maxX_1[0])),
                       exp(22 - maxX_1[1]) / (exp(22 - maxX_1[1]) + exp(24 - maxX_1[1])),
                       exp(24 - maxX_1[1]) / (exp(22 - maxX_1[1]) + exp(24 - maxX_1[1]))]
                    = [exp(-4) / (exp(-4) + exp(-1) + 1),
                       exp(-1) / (exp(-4) + exp(-1) + 1),
                       1 / (exp(-4) + exp(-1) + 1),
                       exp(-2) / (exp(-2) + 1),
                       1 / (exp(-2) + 1)]

    To compute the above using only loops over nnz (= len(X._values())),
    we introduce the index mapping `pool` defined by

      indices = X._indices()
      for i in range(nnz):
          for j in range(nnz):
              if all(indices[k, i] == indices[k, j] for k in range(indices.shape[0]) if k != d):
                  assert pool_d[i] == pool_d[j]
              else:
                  assert pool_d[i] != pool_d[j]

    that is, value entries i and j belong to the same pool iff their
    index tuples differ only along the dimension over which the softmax
    is computed. The `pool` mapping maps X._values() indices to the
    corresponding pool index.

    To save memory and processor resources, we pre-compute the entries
    of the maxX tensor and the sums of exponents as follows:

      mx_d = [max(values[i] for i in range(nnz) if pool_d[i] == k) for k in range(max(pool_d) + 1)]
      exp_sum_d = [sum(exp(values[i] - mx_d[k]) for i in range(nnz) if pool_d[i] == k) for k in range(max(pool_d) + 1)]

    For example, if

      pool_0 = [0, 1, 2, 3, 1]
      pool_1 = [0, 0, 0, 1, 1]

    then

      mx_0 = [11, 24, 15, 22]
      mx_1 = [15, 24]
      exp_sum_0 = [1, (exp(-10) + 1), 1, 1]
      exp_sum_1 = [(exp(-4) + exp(-1) + 1), (exp(-2) + 1)]

    and

      S_0._values() = [exp(11 - mx_0[pool_0[0]]) / exp_sum_0[pool_0[0]],
                       exp(14 - mx_0[pool_0[1]]) / exp_sum_0[pool_0[1]],
                       exp(15 - mx_0[pool_0[2]]) / exp_sum_0[pool_0[2]],
                       exp(22 - mx_0[pool_0[3]]) / exp_sum_0[pool_0[3]],
                       exp(24 - mx_0[pool_0[4]]) / exp_sum_0[pool_0[4]]]

    or in general,

      S_d._values() = [exp(values[i] - mx_d[pool_d[i]]) / exp_sum_d[pool_d[i]] for i in range(nnz)]

    The above algorithm extends straightforwardly to sparse tensors
    with a non-scalar dense part: every scalar operation simply becomes
    the corresponding element-wise tensor operation.

    The implementation below adds further optimizations: it collects
    the pool indices up front so that independent pools can be
    processed concurrently, it minimizes the number of calls to exp,
    and it reuses the softmax machinery for log_softmax.
  */
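  /* A rough Python-style sketch of the per-pool computation that the
     loops below perform (simplified to a 0-D dense part, softmax only):

       pools = get_pools(indices, sizes, dim)
       for pool in pools:                      # independent, run in parallel
           mx = max(values[i] for i in pool)
           exp_sum = sum(exp(values[i] - mx) for i in pool)
           for i in pool:
               out_values[i] = exp(values[i] - mx) / exp_sum

     For log_softmax the last line becomes
       out_values[i] = values[i] - (mx + log(exp_sum))
     which avoids exponentiating the result again. */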
  auto sparse_dim = input.sparse_dim();
  auto indices = input._indices().contiguous();
  auto values = input._values().contiguous();
  auto out_values = output._values();
  auto out_indices = output._indices();
  out_values.resize_as_(values);
  out_indices.resize_as_(indices);
  out_indices.copy_(indices);

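  /* When dim falls into the dense part of the tensor, every sparse
     entry is normalized independently of the others, so the dense
     softmax kernel can be applied directly to the values tensor along
     the corresponding dense dimension. */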
  if (dim >= sparse_dim) {
    if (LogSoftMax) {
      auto new_values =
          at::cpu::_log_softmax(values, dim - sparse_dim + 1, false);
      out_values.set_(new_values);
    } else {
      auto new_values = at::cpu::_softmax(values, dim - sparse_dim + 1, false);
      out_values.set_(new_values);
    }
    return;
  }

  auto nnz = values.size(0);
  auto sizes = input.sizes();
  auto nvalues = get_nvalues(sizes, sparse_dim);

  /* Prepare accessors */
  auto values_2 = values.view({nnz, nvalues});
  auto values_accessor = values_2.accessor<scalar_t, 2>();

  auto out_values_2 = out_values.view({nnz, nvalues});
  auto out_values_accessor = out_values_2.accessor<scalar_t, 2>();

  /* Compute independent pools of indices */
  auto pools = get_pools(indices, sizes, dim);

  int64_t grain_size = 1;
  parallel_for(0, pools.size(), grain_size, [&](int64_t begin, int64_t end) {
      for (const auto p : c10::irange(begin, end)) {
        auto pool_indices = pools[p];

        // Skip empty pools
        if (pool_indices.empty())
          continue;

        /* Prepare scratch space */
        std::vector<scalar_t> mx_row(nvalues, -std::numeric_limits<scalar_t>::infinity());
        std::vector<scalar_t> exp_sums_row(nvalues, 0);

        /* Compute mx */
        for (int64_t i : pool_indices) {
          auto values_row = values_accessor[i];
          for (const auto j : c10::irange(nvalues)) {
            mx_row[j] = std::max(mx_row[j], values_row[j]);
          }
        }

        /* Apply exp to (v - mx) and sum the results */
        for (int64_t i : pool_indices) {
          auto values_row = values_accessor[i];
          auto out_values_row = out_values_accessor[i];
          for (const auto j : c10::irange(nvalues)) {
            auto v = std::exp(values_row[j] - mx_row[j]);
            if (!LogSoftMax) {
              out_values_row[j] = v;
            }
            exp_sums_row[j] += v;
          }
        }

        for (const auto j : c10::irange(nvalues)) {
          if (LogSoftMax) {
            mx_row[j] += std::log(exp_sums_row[j]);
          } else {
            exp_sums_row[j] = 1.0 / exp_sums_row[j];
          }
        }

        /* Normalize with the sum of exponents */
        for (int64_t i : pool_indices) {
          auto values_row = values_accessor[i];
          auto out_values_row = out_values_accessor[i];
          for (const auto j : c10::irange(nvalues)) {
            if (LogSoftMax) {
              out_values_row[j] = values_row[j] - mx_row[j];
            } else {
              out_values_row[j] *= exp_sums_row[j];
            }
          }
        }
      }
    });
}

template <typename scalar_t, bool LogSoftMax>
void cpu_sparse_coo_softmax_backward(const Tensor& grad_input, const Tensor& grad, const Tensor& output, const int64_t dim, ScalarType input_dtype) {
  /*

    If LogSoftMax == false, then

      gI_i = sum_j d<output_j>/d<input_i> * grad_j = sum_j output_i * (1[i==j] - output_j) * grad_j
           = output_i * (grad_i - sum_j output_j * grad_j)

    else

      gI_i = (1 - exp(output_i)) * grad_i - sum_j 1[i!=j] * exp(output_i) * grad_j
           = grad_i - exp(output_i) * sum_j grad_j.

    where

      i, j in range(shape[dim])
      x_i = x[..., i, ...] with the index i taken along dimension `dim`
      output.sparse_dim() == grad.sparse_dim()
  */
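  /* A rough Python-style sketch of the per-pool backward computation
     performed below (simplified to a 0-D dense part and matching
     sparsity patterns of grad and output):

       for pool in pools:                      # pools of `output`
           for i in pool:
               if LogSoftMax:
                   s = sum(grad[j] for j in pool)
                   grad_input[i] = grad[i] - exp(output[i]) * s
               else:
                   t = sum(output[j] * grad[j] for j in pool)
                   grad_input[i] = output[i] * (grad[i] - t)

     The code additionally aligns `grad` and `output` entries through
     their flattened offsets, since their sparsity patterns may differ. */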
  auto sparse_dim = output.sparse_dim();
  auto sizes = output.sizes().vec();
  auto grad_indices = grad._indices().contiguous();
  auto grad_values = grad._values().contiguous();
  auto out_indices = output._indices().contiguous();
  auto out_values = output._values().contiguous();
  auto values = grad_input._values();
  auto indices = grad_input._indices();
  auto out_nnz = out_values.size(0);
  auto grad_nnz = grad_values.size(0);

  values.resize_as_(out_values);
  values.zero_();
  indices.resize_as_(out_indices);
  indices.copy_(out_indices);

  auto out_offsets = get_offsets(out_indices, sizes, -1);
  auto grad_offsets = get_offsets(grad_indices, sizes, -1);

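  /* When dim falls into the dense part, dispatch to the dense backward
     kernels.  If grad and output share the same sparsity pattern the
     kernels are applied to the full values tensors at once; otherwise
     each output entry is matched to the grad entry with the same
     flattened offset (output entries with no matching grad entry keep
     their zero-initialized gradient). */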
  if (dim >= sparse_dim) {
    if (out_offsets == grad_offsets) {
      if (LogSoftMax) {
        auto r = at::cpu::_log_softmax_backward_data(
            grad_values, out_values, dim - sparse_dim + 1, input_dtype);
        values.set_(r);
      } else {
        auto r = at::cpu::_softmax_backward_data(grad_values, out_values, dim - sparse_dim + 1, input_dtype);
        values.set_(r);
      }
    } else {
      for (const auto i : c10::irange(out_nnz)) {
        auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
        auto j = low - grad_offsets.begin();
        if (j < grad_nnz && out_offsets[i] == grad_offsets[j]) {
          if (LogSoftMax) {
            auto r = at::cpu::_log_softmax_backward_data(
                grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
            values[i].copy_(r);
          } else {
            auto r = at::cpu::_softmax_backward_data(grad_values[j], out_values[i], dim - sparse_dim, input_dtype);
            values[i].copy_(r);
          }
        }
      }
    }
    return;
  }

  auto nnz = values.size(0);
  auto nvalues = get_nvalues(sizes, sparse_dim);

  auto values_2 = values.view({nnz, nvalues});
  auto values_accessor = values_2.accessor<scalar_t, 2>();

  auto out_values_2 = out_values.view({out_nnz, nvalues});
  auto out_values_accessor = out_values_2.accessor<scalar_t, 2>();

  auto grad_values_2 = grad_values.view({grad_nnz, nvalues});
  auto grad_values_accessor = grad_values_2.accessor<scalar_t, 2>();

  /* Compute independent pools of indices */
  auto pools = get_pools(out_indices, sizes, dim);

  int64_t grain_size = 1;
  parallel_for(0, pools.size(), grain_size, [&](int64_t begin, int64_t end) {
      for (const auto p : c10::irange(begin, end)) {
        auto pool_indices = pools[p];

        // Skip empty pools
        if (pool_indices.empty())
          continue;

        std::vector<scalar_t> tmp_row(nvalues, 0);

        /* Compute tmp = - sum_j output_j * grad_j */
        for (int64_t i : pool_indices) {
          auto out_values_row = out_values_accessor[i];
          auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
          auto j = low - grad_offsets.begin();

          if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
            auto grad_values_row = grad_values_accessor[j];
            for (const auto k : c10::irange(nvalues)) {
              if (LogSoftMax) {
                tmp_row[k] -= grad_values_row[k];
              } else {
                tmp_row[k] -= out_values_row[k] * grad_values_row[k];
              }
            }
          }
        }

        /* Compute grad_input = output * (grad + tmp) */
        for (int64_t i : pool_indices) {
          auto out_values_row = out_values_accessor[i];
          auto values_row = values_accessor[i];
          auto low = std::lower_bound(grad_offsets.begin(), grad_offsets.end(), out_offsets[i]);
          auto j = low - grad_offsets.begin();

          if (j < grad_nnz && (out_offsets[i] == grad_offsets[j])) {
            auto grad_values_row = grad_values_accessor[j];
            for (const auto k : c10::irange(nvalues)) {
              if (LogSoftMax) {
                values_row[k] = grad_values_row[k] + std::exp(out_values_row[k]) * tmp_row[k];
              } else {
                values_row[k] = out_values_row[k] * (grad_values_row[k] + tmp_row[k]);
              }
            }
          } else {
            for (const auto k : c10::irange(nvalues)) {
              if (LogSoftMax) {
                values_row[k] = std::exp(out_values_row[k]) * tmp_row[k];
              } else {
                values_row[k] = out_values_row[k] * (tmp_row[k]);
              }
            }
          }
        }
      }
    });
}

} // anonymous namespace

Tensor softmax_sparse_cpu(
    const Tensor& input_,
    const int64_t dim_,
    const bool half_to_float) {
  Tensor input, output;
  int64_t dim;
  std::tie(input, output, dim) = softmax_sparse_input_preprocessing(
      input_, dim_, half_to_float, "softmax");
  if (input.numel() == 0) {
    return output;
  }
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "softmax", [&] {
    cpu_sparse_coo_softmax<scalar_t, false>(output, input, dim);
  });
  return output;
}

Tensor log_softmax_sparse_cpu(
    const Tensor& input_,
    const int64_t dim_,
    const bool half_to_float) {
  Tensor input, output;
  int64_t dim;
  std::tie(input, output, dim) = softmax_sparse_input_preprocessing(
      input_, dim_, half_to_float, "log_softmax");
  if (input.numel() == 0) {
    return output;
  }
  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "log_softmax", [&] {
    cpu_sparse_coo_softmax<scalar_t, true>(output, input, dim);
  });
  return output;
}

Tensor softmax_backward_sparse_cpu(
    const Tensor& grad_,
    const Tensor& output_,
    int64_t dim_,
    const Tensor& input_) {
  Tensor grad_input, grad, output;
  int64_t dim;
  std::tie(grad_input, grad, output, dim) =
      softmax_backward_sparse_input_preprocessing(
          grad_, output_, dim_, input_, "softmax_backward");
  if (output.numel() == 0) {
    return grad_input;
  }
  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "softmax_backward", [&] {
    cpu_sparse_coo_softmax_backward<scalar_t, false>(
        grad_input, grad, output, dim_, input_.scalar_type());
  });
  return grad_input;
}

Tensor log_softmax_backward_sparse_cpu(
    const Tensor& grad_,
    const Tensor& output_,
    int64_t dim_,
    const Tensor& input_) {
  Tensor grad_input, grad, output;
  int64_t dim;
  std::tie(grad_input, grad, output, dim) =
      softmax_backward_sparse_input_preprocessing(
          grad_, output_, dim_, input_, "log_softmax_backward");
  if (output.numel() == 0) {
    return grad_input;
  }
  AT_DISPATCH_FLOATING_TYPES(grad.scalar_type(), "log_softmax_backward", [&] {
    cpu_sparse_coo_softmax_backward<scalar_t, true>(
        grad_input, grad, output, dim_, input_.scalar_type());
  });
  return grad_input;
}

Tensor _sparse_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
        return at::_sparse_softmax(input_, dim_, true);
    } else {
        Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
        return at::_sparse_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

Tensor _sparse_softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::_sparse_softmax(self, dimname_to_position(self, dim), dtype);
}
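
/* These overloads back the Python-level torch.sparse.softmax entry
   point (the _sparse_log_softmax overloads below behave analogously
   for torch.sparse.log_softmax).  Roughly, typical usage looks like:

     # Python
     x = torch.tensor([[11., 0, 0, 14, 15],
                       [0, 22, 0, 24, 0]]).to_sparse()
     y = torch.sparse.softmax(x, dim=1)   # same indices as x

   Unspecified entries are treated as -inf, so each specified entry is
   normalized only against the other specified entries in its row. */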

Tensor _sparse_log_softmax(const Tensor& input_, const int64_t dim_, std::optional<ScalarType> dtype) {
  auto result = [&]() {
    NoNamesGuard guard;
    if (input_.is_cuda() && input_.scalar_type() == ScalarType::Half && dtype == ScalarType::Float) {
        return at::_sparse_log_softmax(input_, dim_, true);
    } else {
        Tensor converted = dtype.has_value() ? input_.toType(dtype.value()) : input_;
        return at::_sparse_log_softmax(converted, dim_, false);
    }
  }();
  namedinference::propagate_names(result, input_);
  return result;
}

Tensor _sparse_log_softmax(const Tensor& self, Dimname dim, std::optional<ScalarType> dtype) {
  return at::_sparse_log_softmax(self, dimname_to_position(self, dim), dtype);
}

} // namespace at::native