1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/TensorUtils.h> 5 #include <tuple> 6 7 namespace at::native { 8 9 TORCH_API std::tuple<Tensor, Tensor, int64_t> softmax_sparse_input_preprocessing( 10 const Tensor& input_, 11 const int64_t dim_, 12 const bool half_to_float, 13 CheckedFrom function_name); 14 15 TORCH_API std::tuple<Tensor, Tensor, Tensor, int64_t> softmax_backward_sparse_input_preprocessing( 16 const Tensor& grad_, 17 const Tensor& output_, 18 int64_t dim_, 19 const Tensor& input_, 20 CheckedFrom function_name); 21 22 } // namespace at::native 23