// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/sparse/ParamUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/TensorUtils.h>

#include <cstdint>
#include <tuple>

7 namespace at::native {
8 
9 TORCH_API std::tuple<Tensor, Tensor, int64_t> softmax_sparse_input_preprocessing(
10     const Tensor& input_,
11     const int64_t dim_,
12     const bool half_to_float,
13     CheckedFrom function_name);
14 
15 TORCH_API std::tuple<Tensor, Tensor, Tensor, int64_t> softmax_backward_sparse_input_preprocessing(
16     const Tensor& grad_,
17     const Tensor& output_,
18     int64_t dim_,
19     const Tensor& input_,
20     CheckedFrom function_name);
21 
22 } // namespace at::native
23