1 #pragma once 2 3 // Indexing tensors by tensors 4 5 #include <ATen/core/List.h> 6 #include <ATen/core/Tensor.h> 7 #include <ATen/native/DispatchStub.h> 8 #include <ATen/native/ReductionType.h> 9 10 namespace at { 11 struct TensorIterator; 12 } 13 14 namespace at::native { 15 16 using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<std::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe); 17 using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<std::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe); 18 using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index); 19 using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src); 20 using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src); 21 using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src); 22 using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, 23 const Tensor& src, const ReductionType& reduce); 24 using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, 25 const Scalar& value, const ReductionType& reduce); 26 using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index, 27 const Tensor& src, const ReductionType& reduce); 28 29 DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub); 30 DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub); 31 DECLARE_DISPATCH(gather_fn, gather_stub); 32 DECLARE_DISPATCH(scatter_fn, scatter_stub); 33 DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub); 34 DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub); 35 DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub); 36 DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub); 37 DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub); 38 39 TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<std::optional<at::Tensor>>& indices); 40 41 using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&); 42 using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool); 43 using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&); 44 45 DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub); 46 DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub); 47 DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub); 48 49 } // namespace at::native 50