// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorAdvancedIndexing.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // Indexing tensors by tensors
4 
5 #include <ATen/core/List.h>
6 #include <ATen/core/Tensor.h>
7 #include <ATen/native/DispatchStub.h>
8 #include <ATen/native/ReductionType.h>
9 
10 namespace at {
11 struct TensorIterator;
12 }
13 
14 namespace at::native {
15 
16 using index_put_with_sort_fn = void(*)(Tensor &, const c10::List<std::optional<Tensor>> &, const Tensor &, bool accumulate, bool unsafe);
17 using index_put_with_sort_quantized_fn = void(*)(Tensor& self, const c10::List<std::optional<Tensor>>& indices, const Tensor& value, double scale, int zero_point, bool unsafe);
18 using gather_fn = void (*)(const Tensor & result, const Tensor & self, int64_t dim, const Tensor & index);
19 using scatter_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
20 using scatter_fill_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src);
21 using scatter_add_fn = void(*)(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src);
22 using scatter_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
23                                   const Tensor& src, const ReductionType& reduce);
24 using scatter_scalar_reduce_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
25                                          const Scalar& value, const ReductionType& reduce);
26 using scatter_reduce_two_fn = void(*)(const Tensor& self, const int64_t dim, const Tensor& index,
27                                       const Tensor& src, const ReductionType& reduce);
28 
29 DECLARE_DISPATCH(index_put_with_sort_fn, index_put_with_sort_stub);
30 DECLARE_DISPATCH(index_put_with_sort_quantized_fn, index_put_with_sort_quantized_stub);
31 DECLARE_DISPATCH(gather_fn, gather_stub);
32 DECLARE_DISPATCH(scatter_fn, scatter_stub);
33 DECLARE_DISPATCH(scatter_fill_fn, scatter_fill_stub);
34 DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
35 DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
36 DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
37 DECLARE_DISPATCH(scatter_reduce_two_fn, scatter_reduce_two_stub);
38 
39 TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<std::optional<at::Tensor>>& indices);
40 
41 using scatter_add_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&);
42 using scatter_reduce_expanded_index_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const ReductionType& reduce, bool);
43 using gather_expanded_index_fn = void (*)(const Tensor&, const Tensor&, const Tensor&);
44 
45 DECLARE_DISPATCH(scatter_add_expanded_index_fn, scatter_add_expanded_index_stub);
46 DECLARE_DISPATCH(scatter_reduce_expanded_index_fn, scatter_reduce_expanded_index_stub);
47 DECLARE_DISPATCH(gather_expanded_index_fn, gather_expanded_index_stub);
48 
49 } // namespace at::native
50