1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/sparse/ParamUtils.h>
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorUtils.h>
5 #include <ATen/WrapDimUtils.h>
6 #include <tuple>
7
8 #ifndef AT_PER_OPERATOR_HEADERS
9 #include <ATen/NativeFunctions.h>
10 #else
11 #include <ATen/ops/empty_like_native.h>
12 #endif
13
14 namespace at::native {
15
softmax_sparse_input_preprocessing(const Tensor & input_,const int64_t dim_,const bool half_to_float,CheckedFrom function_name)16 std::tuple<Tensor, Tensor, int64_t> softmax_sparse_input_preprocessing(
17 const Tensor& input_,
18 const int64_t dim_,
19 const bool half_to_float,
20 CheckedFrom function_name) {
21 TORCH_INTERNAL_ASSERT(input_.is_sparse());
22 TORCH_CHECK(
23 !half_to_float,
24 std::string(function_name) +
25 ": with half to float conversion is not supported on " +
26 input_.device().str());
27 auto input = input_.coalesce();
28 Tensor output = at::native::empty_like_sparse_coo(input);
29 int64_t dim = c10::maybe_wrap_dim(dim_, input.dim());
30 return std::make_tuple(input, output, dim);
31 }
32
softmax_backward_sparse_input_preprocessing(const Tensor & grad_,const Tensor & output_,int64_t dim_,const Tensor & input_,CheckedFrom function_name)33 std::tuple<Tensor, Tensor, Tensor, int64_t> softmax_backward_sparse_input_preprocessing(
34 const Tensor& grad_,
35 const Tensor& output_,
36 int64_t dim_,
37 const Tensor& input_,
38 CheckedFrom function_name) {
39 TensorArg grad_arg{grad_, "grad", 1}, output_arg{output_, "output", 2};
40 checkSameSize(function_name, grad_arg, output_arg);
41
42 int64_t dim = maybe_wrap_dim(dim_, grad_.dim());
43
44 auto grad = grad_.coalesce();
45 auto output = output_.coalesce();
46
47 Tensor grad_input = at::native::empty_like_sparse_coo(output);
48 TORCH_CHECK(
49 grad.sparse_dim() == output.sparse_dim(),
50 ": grad and output sparse dimensions must be equal");
51 return std::make_tuple(grad_input, grad, output, dim);
52 }
53
54 } // namespace at::native
55