xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/ParamUtils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/sparse/ParamUtils.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorUtils.h>
5 #include <ATen/WrapDimUtils.h>
6 #include <tuple>
7 
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/empty_like_native.h>
12 #endif
13 
14 namespace at::native {
15 
softmax_sparse_input_preprocessing(const Tensor & input_,const int64_t dim_,const bool half_to_float,CheckedFrom function_name)16 std::tuple<Tensor, Tensor, int64_t> softmax_sparse_input_preprocessing(
17     const Tensor& input_,
18     const int64_t dim_,
19     const bool half_to_float,
20     CheckedFrom function_name) {
21   TORCH_INTERNAL_ASSERT(input_.is_sparse());
22   TORCH_CHECK(
23       !half_to_float,
24       std::string(function_name) +
25           ": with half to float conversion is not supported on " +
26           input_.device().str());
27   auto input = input_.coalesce();
28   Tensor output = at::native::empty_like_sparse_coo(input);
29   int64_t dim = c10::maybe_wrap_dim(dim_, input.dim());
30   return std::make_tuple(input, output, dim);
31 }
32 
softmax_backward_sparse_input_preprocessing(const Tensor & grad_,const Tensor & output_,int64_t dim_,const Tensor & input_,CheckedFrom function_name)33 std::tuple<Tensor, Tensor, Tensor, int64_t> softmax_backward_sparse_input_preprocessing(
34     const Tensor& grad_,
35     const Tensor& output_,
36     int64_t dim_,
37     const Tensor& input_,
38     CheckedFrom function_name) {
39   TensorArg grad_arg{grad_, "grad", 1}, output_arg{output_, "output", 2};
40   checkSameSize(function_name, grad_arg, output_arg);
41 
42   int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
43 
44   auto grad = grad_.coalesce();
45   auto output = output_.coalesce();
46 
47   Tensor grad_input = at::native::empty_like_sparse_coo(output);
48   TORCH_CHECK(
49       grad.sparse_dim() == output.sparse_dim(),
50       ": grad and output sparse dimensions must be equal");
51   return std::make_tuple(grad_input, grad, output, dim);
52 }
53 
54 } // namespace at::native
55