1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/index_util.h>
10 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
11
12 namespace torch {
13 namespace executor {
14
check_gather_args(const Tensor & in,int64_t dim,const Tensor & index,bool sparse_grad,Tensor & out)15 bool check_gather_args(
16 const Tensor& in,
17 int64_t dim,
18 const Tensor& index,
19 bool sparse_grad,
20 Tensor& out) {
21 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
22 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
23 ET_LOG_MSG_AND_RETURN_IF_FALSE(
24 index.scalar_type() == ScalarType::Long,
25 "Expected dypte int64 for index");
26 if (index.numel() != 0) {
27 ET_LOG_MSG_AND_RETURN_IF_FALSE(
28 nonzero_dim(in) == nonzero_dim(index),
29 "self and index should have the same dimensionality when index is not empty "
30 "except for the case when one has dimension 0 and the other has dimension 1");
31 }
32
33 // Normalize dim to non-negative value
34 if (dim < 0) {
35 dim += nonzero_dim(in);
36 }
37
38 for (size_t d = 0; d < nonzero_dim(in); ++d) {
39 if (d != dim) {
40 ET_LOG_MSG_AND_RETURN_IF_FALSE(
41 nonempty_size(index, d) <= nonempty_size(in, d),
42 "size of dimension %zd of index should be smaller than the size of that dimension of input if dimension %zd != dim %zd",
43 d,
44 d,
45 (size_t)dim);
46 }
47 }
48 const long* index_data = index.const_data_ptr<long>();
49 for (size_t i = 0; i < index.numel(); ++i) {
50 ET_LOG_MSG_AND_RETURN_IF_FALSE(
51 index_data[i] >= 0 && index_data[i] < nonempty_size(in, dim),
52 "Index is out of bounds for dimension %zd with size %zd",
53 (size_t)dim,
54 nonempty_size(index, dim));
55 }
56
57 return true;
58 }
59
check_index_select_args(const Tensor & in,int64_t dim,const Tensor & index,Tensor & out)60 bool check_index_select_args(
61 const Tensor& in,
62 int64_t dim,
63 const Tensor& index,
64 Tensor& out) {
65 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
66 dim = dim < 0 ? dim + nonzero_dim(in) : dim;
67 ET_LOG_MSG_AND_RETURN_IF_FALSE(
68 nonempty_size(in, dim) > 0,
69 "index_select: Indexing axis dim should be positive");
70
71 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
72 ET_LOG_MSG_AND_RETURN_IF_FALSE(
73 index.scalar_type() == ScalarType::Long ||
74 index.scalar_type() == ScalarType::Int,
75 "Expected index to have type of Long or Int, but found %s",
76 toString(index.scalar_type()));
77
78 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_rank_smaller_or_equal_to(index, 1));
79 if (index.dim() > 0 && in.dim() == 0) {
80 ET_LOG_MSG_AND_RETURN_IF_FALSE(
81 index.numel() == 1,
82 "index_select: Index to scalar must have exactly 1 value");
83 }
84
85 if (index.scalar_type() == ScalarType::Long) {
86 const int64_t* const index_ptr = index.const_data_ptr<int64_t>();
87 for (size_t i = 0; i < index.numel(); ++i) {
88 ET_LOG_MSG_AND_RETURN_IF_FALSE(
89 index_ptr[i] >= 0 && index_ptr[i] < nonempty_size(in, dim),
90 "index[%zu] = %" PRId64 " is out of range [0, %zd)",
91 i,
92 index_ptr[i],
93 static_cast<size_t>(nonempty_size(in, dim)));
94 }
95 } else {
96 const int32_t* const index_ptr = index.const_data_ptr<int32_t>();
97 for (size_t i = 0; i < index.numel(); ++i) {
98 ET_LOG_MSG_AND_RETURN_IF_FALSE(
99 index_ptr[i] >= 0 && index_ptr[i] < nonempty_size(in, dim),
100 "index[%zu] = %" PRId32 " is out of range [0, %zd)",
101 i,
102 index_ptr[i],
103 static_cast<size_t>(nonempty_size(in, dim)));
104 }
105 }
106
107 return true;
108 }
109
get_index_select_out_target_size(const Tensor & in,int64_t dim,const Tensor & index,exec_aten::SizesType * out_sizes,size_t * out_ndim)110 void get_index_select_out_target_size(
111 const Tensor& in,
112 int64_t dim,
113 const Tensor& index,
114 exec_aten::SizesType* out_sizes,
115 size_t* out_ndim) {
116 *out_ndim = in.dim();
117 for (size_t i = 0; i < in.dim(); ++i) {
118 if (i == dim) {
119 out_sizes[i] = index.numel();
120 } else {
121 out_sizes[i] = in.size(i);
122 }
123 }
124 }
125
check_nonzero_args(const Tensor & in,const Tensor & out)126 bool check_nonzero_args(const Tensor& in, const Tensor& out) {
127 (void)in;
128
129 ET_LOG_MSG_AND_RETURN_IF_FALSE(
130 out.scalar_type() == ScalarType::Long,
131 "Expected out to be a Long tensor but received %" PRId8,
132 static_cast<int8_t>(out.scalar_type()));
133
134 ET_LOG_MSG_AND_RETURN_IF_FALSE(
135 out.dim() == 2,
136 "Expected out to be a 2d tensor received %zd",
137 ssize_t(out.dim()));
138
139 return true;
140 }
141
check_scatter_add_args(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)142 bool check_scatter_add_args(
143 const Tensor& self,
144 int64_t dim,
145 const Tensor& index,
146 const Tensor& src,
147 Tensor& out) {
148 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, out));
149 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(self, src));
150 ET_LOG_MSG_AND_RETURN_IF_FALSE(
151 index.scalar_type() == ScalarType::Long,
152 "Expected dypte int64 for index");
153 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(self, dim));
154
155 if (index.numel() == 0) {
156 return true;
157 }
158
159 ET_LOG_MSG_AND_RETURN_IF_FALSE(
160 nonzero_dim(self) == nonzero_dim(src) &&
161 nonzero_dim(self) == nonzero_dim(index),
162 "self, index and src should have same number of dimensions.");
163
164 // Normalize dim to non-negative value
165 if (dim < 0) {
166 dim += nonzero_dim(self);
167 }
168
169 for (size_t d = 0; d < nonzero_dim(self); ++d) {
170 ET_LOG_MSG_AND_RETURN_IF_FALSE(
171 nonempty_size(index, d) <= nonempty_size(src, d),
172 "size of dimension %zd of index should be smaller than the size of that dimension of src",
173 d);
174 if (d != dim) {
175 ET_LOG_MSG_AND_RETURN_IF_FALSE(
176 nonempty_size(index, d) <= nonempty_size(self, d),
177 "size of dimension %zd of index should be smaller than the size of that dimension of self if dimension %zd != dim %zd",
178 d,
179 d,
180 (size_t)dim);
181 }
182 }
183 const long* index_data = index.const_data_ptr<long>();
184 for (size_t i = 0; i < index.numel(); ++i) {
185 ET_LOG_MSG_AND_RETURN_IF_FALSE(
186 index_data[i] >= 0 && index_data[i] < nonempty_size(self, dim),
187 "Index is out of bounds for dimension %zd with size %zd",
188 (size_t)dim,
189 nonempty_size(self, dim));
190 }
191 return true;
192 }
193
check_scatter_src_args(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)194 bool check_scatter_src_args(
195 const Tensor& self,
196 int64_t dim,
197 const Tensor& index,
198 const Tensor& src,
199 Tensor& out) {
200 return check_scatter_add_args(self, dim, index, src, out);
201 }
202
check_scatter_value_args(const Tensor & self,int64_t dim,const Tensor & index,const Scalar & value,Tensor & out)203 bool check_scatter_value_args(
204 const Tensor& self,
205 int64_t dim,
206 const Tensor& index,
207 const Scalar& value,
208 Tensor& out) {
209 return check_gather_args(self, dim, index, false, out);
210 }
211
check_select_scatter_args(const Tensor & in,const Tensor & src,int64_t dim,int64_t index,Tensor & output)212 bool check_select_scatter_args(
213 const Tensor& in,
214 const Tensor& src,
215 int64_t dim,
216 int64_t index,
217 Tensor& output) {
218 /**
219 * Assumptions for inputs:
220 * 1. output size is the same as input size
221 * 2. src size is the same as the selected slice from the input
222 * 3. dim and index values are valid given the input tensor
223 */
224
225 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, output));
226
227 // The dim planed to be selected on shall exist in input
228 ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, in.dim()));
229
230 // The index shall be valid in the given dimenson
231 ET_LOG_MSG_AND_RETURN_IF_FALSE(
232 index >= 0 && index < in.size(dim),
233 "index %" PRId64 " out of range [-%zd,%zd) at in.size( %" PRId64 ")",
234 index,
235 in.size(dim),
236 in.size(dim),
237 dim);
238
239 // The src.dim() shall be one lower than in.dim() since src needs to fit
240 // into the selected data on one dim of input
241 // https://pytorch.org/docs/stable/generated/torch.select_scatter.html
242 ET_LOG_MSG_AND_RETURN_IF_FALSE(
243 in.dim() == src.dim() + 1,
244 "in.dim() %zd != src.dim() + 1 %zd",
245 in.dim(),
246 src.dim() + 1);
247
248 // The size of src tensor should follow these rules:
249 // - src.size(i) shall equal to in.size(i) if i < dim,
250 // - src.size(i) shall equal to in.size(i+1) if i >= dim
251
252 for (ssize_t d = 0; d < in.dim() - 1; d++) {
253 if (d < dim) {
254 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, d, src, d));
255 } else {
256 ET_LOG_AND_RETURN_IF_FALSE(
257 tensors_have_same_size_at_dims(in, d + 1, src, d));
258 }
259 }
260
261 return true;
262 }
263
264 } // namespace executor
265 } // namespace torch
266