xref: /aosp_15_r20/external/executorch/kernels/portable/cpu/util/index_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/portable/cpu/util/index_util.h>
10 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
11 
12 namespace torch {
13 namespace executor {
14 
check_gather_args(const Tensor & in,int64_t dim,const Tensor & index,bool sparse_grad,Tensor & out)15 bool check_gather_args(
16     const Tensor& in,
17     int64_t dim,
18     const Tensor& index,
19     bool sparse_grad,
20     Tensor& out) {
21   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
22   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
23   ET_LOG_MSG_AND_RETURN_IF_FALSE(
24       index.scalar_type() == ScalarType::Long,
25       "Expected dypte int64 for index");
26   if (index.numel() != 0) {
27     ET_LOG_MSG_AND_RETURN_IF_FALSE(
28         nonzero_dim(in) == nonzero_dim(index),
29         "self and index should have the same dimensionality when index is not empty "
30         "except for the case when one has dimension 0 and the other has dimension 1");
31   }
32 
33   // Normalize dim to non-negative value
34   if (dim < 0) {
35     dim += nonzero_dim(in);
36   }
37 
38   for (size_t d = 0; d < nonzero_dim(in); ++d) {
39     if (d != dim) {
40       ET_LOG_MSG_AND_RETURN_IF_FALSE(
41           nonempty_size(index, d) <= nonempty_size(in, d),
42           "size of dimension %zd of index should be smaller than the size of that dimension of input if dimension %zd != dim %zd",
43           d,
44           d,
45           (size_t)dim);
46     }
47   }
48   const long* index_data = index.const_data_ptr<long>();
49   for (size_t i = 0; i < index.numel(); ++i) {
50     ET_LOG_MSG_AND_RETURN_IF_FALSE(
51         index_data[i] >= 0 && index_data[i] < nonempty_size(in, dim),
52         "Index is out of bounds for dimension %zd with size %zd",
53         (size_t)dim,
54         nonempty_size(index, dim));
55   }
56 
57   return true;
58 }
59 
check_index_select_args(const Tensor & in,int64_t dim,const Tensor & index,Tensor & out)60 bool check_index_select_args(
61     const Tensor& in,
62     int64_t dim,
63     const Tensor& index,
64     Tensor& out) {
65   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
66   dim = dim < 0 ? dim + nonzero_dim(in) : dim;
67   ET_LOG_MSG_AND_RETURN_IF_FALSE(
68       nonempty_size(in, dim) > 0,
69       "index_select: Indexing axis dim should be positive");
70 
71   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
72   ET_LOG_MSG_AND_RETURN_IF_FALSE(
73       index.scalar_type() == ScalarType::Long ||
74           index.scalar_type() == ScalarType::Int,
75       "Expected index to have type of Long or Int, but found %s",
76       toString(index.scalar_type()));
77 
78   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_smaller_or_equal_to(index, 1));
79   if (index.dim() > 0 && in.dim() == 0) {
80     ET_LOG_MSG_AND_RETURN_IF_FALSE(
81         index.numel() == 1,
82         "index_select: Index to scalar must have exactly 1 value");
83   }
84 
85   if (index.scalar_type() == ScalarType::Long) {
86     const int64_t* const index_ptr = index.const_data_ptr<int64_t>();
87     for (size_t i = 0; i < index.numel(); ++i) {
88       ET_LOG_MSG_AND_RETURN_IF_FALSE(
89           index_ptr[i] >= 0 && index_ptr[i] < nonempty_size(in, dim),
90           "index[%zu] = %" PRId64 " is out of range [0, %zd)",
91           i,
92           index_ptr[i],
93           static_cast<size_t>(nonempty_size(in, dim)));
94     }
95   } else {
96     const int32_t* const index_ptr = index.const_data_ptr<int32_t>();
97     for (size_t i = 0; i < index.numel(); ++i) {
98       ET_LOG_MSG_AND_RETURN_IF_FALSE(
99           index_ptr[i] >= 0 && index_ptr[i] < nonempty_size(in, dim),
100           "index[%zu] = %" PRId32 " is out of range [0, %zd)",
101           i,
102           index_ptr[i],
103           static_cast<size_t>(nonempty_size(in, dim)));
104     }
105   }
106 
107   return true;
108 }
109 
get_index_select_out_target_size(const Tensor & in,int64_t dim,const Tensor & index,exec_aten::SizesType * out_sizes,size_t * out_ndim)110 void get_index_select_out_target_size(
111     const Tensor& in,
112     int64_t dim,
113     const Tensor& index,
114     exec_aten::SizesType* out_sizes,
115     size_t* out_ndim) {
116   *out_ndim = in.dim();
117   for (size_t i = 0; i < in.dim(); ++i) {
118     if (i == dim) {
119       out_sizes[i] = index.numel();
120     } else {
121       out_sizes[i] = in.size(i);
122     }
123   }
124 }
125 
check_nonzero_args(const Tensor & in,const Tensor & out)126 bool check_nonzero_args(const Tensor& in, const Tensor& out) {
127   (void)in;
128 
129   ET_LOG_MSG_AND_RETURN_IF_FALSE(
130       out.scalar_type() == ScalarType::Long,
131       "Expected out to be a Long tensor but received %" PRId8,
132       static_cast<int8_t>(out.scalar_type()));
133 
134   ET_LOG_MSG_AND_RETURN_IF_FALSE(
135       out.dim() == 2,
136       "Expected out to be a 2d tensor received %zd",
137       ssize_t(out.dim()));
138 
139   return true;
140 }
141 
check_scatter_add_args(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)142 bool check_scatter_add_args(
143     const Tensor& self,
144     int64_t dim,
145     const Tensor& index,
146     const Tensor& src,
147     Tensor& out) {
148   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, out));
149   ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, src));
150   ET_LOG_MSG_AND_RETURN_IF_FALSE(
151       index.scalar_type() == ScalarType::Long,
152       "Expected dypte int64 for index");
153   ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(self, dim));
154 
155   if (index.numel() == 0) {
156     return true;
157   }
158 
159   ET_LOG_MSG_AND_RETURN_IF_FALSE(
160       nonzero_dim(self) == nonzero_dim(src) &&
161           nonzero_dim(self) == nonzero_dim(index),
162       "self, index and src should have same number of dimensions.");
163 
164   // Normalize dim to non-negative value
165   if (dim < 0) {
166     dim += nonzero_dim(self);
167   }
168 
169   for (size_t d = 0; d < nonzero_dim(self); ++d) {
170     ET_LOG_MSG_AND_RETURN_IF_FALSE(
171         nonempty_size(index, d) <= nonempty_size(src, d),
172         "size of dimension %zd of index should be smaller than the size of that dimension of src",
173         d);
174     if (d != dim) {
175       ET_LOG_MSG_AND_RETURN_IF_FALSE(
176           nonempty_size(index, d) <= nonempty_size(self, d),
177           "size of dimension %zd of index should be smaller than the size of that dimension of self if dimension %zd != dim %zd",
178           d,
179           d,
180           (size_t)dim);
181     }
182   }
183   const long* index_data = index.const_data_ptr<long>();
184   for (size_t i = 0; i < index.numel(); ++i) {
185     ET_LOG_MSG_AND_RETURN_IF_FALSE(
186         index_data[i] >= 0 && index_data[i] < nonempty_size(self, dim),
187         "Index is out of bounds for dimension %zd with size %zd",
188         (size_t)dim,
189         nonempty_size(self, dim));
190   }
191   return true;
192 }
193 
// scatter.src places values of src at the positions named by index, so its
// arguments obey exactly the same constraints as scatter_add; delegate the
// validation wholesale.
bool check_scatter_src_args(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const Tensor& src,
    Tensor& out) {
  return check_scatter_add_args(self, dim, index, src, out);
}
202 
// scatter.value has no src tensor (a single scalar is scattered), so the
// remaining arguments obey the same dtype/shape/bounds rules as gather.
// `value` itself needs no validation; sparse_grad is passed as false since
// it is irrelevant here (and unused by check_gather_args).
bool check_scatter_value_args(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    const Scalar& value,
    Tensor& out) {
  return check_gather_args(self, dim, index, false, out);
}
211 
/**
 * Validates the arguments of select_scatter: `src` must be shaped exactly
 * like the slice of `in` selected by (`dim`, `index`), and `in`/`output`
 * must share a dtype. Logs and returns false on the first failed check.
 */
bool check_select_scatter_args(
    const Tensor& in,
    const Tensor& src,
    int64_t dim,
    int64_t index,
    Tensor& output) {
  /**
   * Assumptions for inputs:
   * 1. output size is the same as input size
   * 2. src size is the same as the selected slice from the input
   * 3. dim and index values are valid given the input tensor
   */

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, output));

  // The dim planned to be selected on shall exist in input
  ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, in.dim()));

  // The index shall be valid in the given dimension.
  // NOTE(review): the condition accepts only [0, in.size(dim)) while the
  // message advertises [-size, size) — presumably callers normalize negative
  // indices before calling; confirm, otherwise message and check disagree.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      index >= 0 && index < in.size(dim),
      "index %" PRId64 " out of range [-%zd,%zd) at in.size( %" PRId64 ")",
      index,
      in.size(dim),
      in.size(dim),
      dim);

  // The src.dim() shall be one lower than in.dim() since src needs to fit
  // into the selected data on one dim of input
  // https://pytorch.org/docs/stable/generated/torch.select_scatter.html
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.dim() == src.dim() + 1,
      "in.dim() %zd != src.dim() + 1 %zd",
      in.dim(),
      src.dim() + 1);

  // The size of src tensor should follow these rules:
  // - src.size(i) shall equal to in.size(i) if i < dim,
  // - src.size(i) shall equal to in.size(i+1) if i >= dim
  // (i.e. src's shape is in's shape with the `dim` axis removed.)

  for (ssize_t d = 0; d < in.dim() - 1; d++) {
    if (d < dim) {
      ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, d, src, d));
    } else {
      // Past the removed axis, src's dim d corresponds to in's dim d + 1.
      ET_LOG_AND_RETURN_IF_FALSE(
          tensors_have_same_size_at_dims(in, d + 1, src, d));
    }
  }

  return true;
}
263 
264 } // namespace executor
265 } // namespace torch
266