1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17 #include <cmath>
18
19 using namespace ::testing;
20 using exec_aten::Scalar;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24
25 class OpScatterAddOutTest : public OperatorTest {
26 protected:
op_scatter_add_out(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)27 Tensor& op_scatter_add_out(
28 const Tensor& self,
29 int64_t dim,
30 const Tensor& index,
31 const Tensor& src,
32 Tensor& out) {
33 return torch::executor::aten::scatter_add_outf(
34 context_, self, dim, index, src, out);
35 }
36
37 // Common testing for the operator
38 template <ScalarType DATA_DTYPE>
test_scatter_add_out()39 void test_scatter_add_out() {
40 TensorFactory<ScalarType::Long> tf_index;
41 TensorFactory<DATA_DTYPE> tf_data;
42 const std::vector<int32_t> sizes = {3, 5};
43 // clang-format off
44 Tensor src = tf_data.make(
45 /*sizes=*/{2, 5},
46 {
47 1, 2, 3, 4, 5,
48 6, 7, 8, 9, 10
49 });
50 // clang-format on
51 Tensor self = tf_data.zeros(sizes);
52 Tensor out = tf_data.zeros(sizes);
53 // clang-format off
54 Tensor index = tf_index.make(
55 /*sizes=*/{2, 3},
56 {
57 0, 1, 2,
58 0, 1, 2
59 });
60 // clang-format on
61
62 // Valid input should give the expected output
63 op_scatter_add_out(self, 0, index, src, out);
64 // clang-format off
65 EXPECT_TENSOR_EQ(
66 out, tf_data.make(
67 sizes,
68 {
69 7, 0, 0, 0, 0,
70 0, 9, 0, 0, 0,
71 0, 0, 11, 0, 0
72 }));
73 // clang-format on
74
75 // Valid input should give the expected output
76 op_scatter_add_out(self, 1, index, src, out);
77 // clang-format off
78 EXPECT_TENSOR_EQ(
79 out, tf_data.make(sizes,
80 {
81 1, 2, 3, 0, 0,
82 6, 7, 8, 0, 0,
83 0, 0, 0, 0, 0
84 }));
85
86 src = tf_data.make(
87 /*sizes=*/{2, 3, 3},
88 {
89 // [0, :, :]
90 1, 2, 3,
91 4, 5, 6,
92 7, 8, 9,
93
94 // [1, :, :]
95 10, 11, 12,
96 13, 14, 15,
97 16, 17, 18
98 });
99 // clang-format on
100 self = tf_data.ones(/*sizes=*/{2, 3, 3});
101 out = tf_data.zeros(/*sizes=*/{2, 3, 3});
102 // clang-format off
103 index = tf_index.make(
104 /*sizes=*/{1, 3, 2},
105 {
106 0, 1,
107 1, 2,
108 0, 2
109 });
110 // clang-format on
111
112 op_scatter_add_out(self, 1, index, src, out);
113 // clang-format off
114 EXPECT_TENSOR_EQ(
115 out,
116 tf_data.make(
117 /*sizes=*/{2, 3, 3},
118 {
119 // [0, :, :]
120 9, 1, 1,
121 5, 3, 1,
122 1, 14, 1,
123
124 // [1, :, :]
125 1, 1, 1,
126 1, 1, 1,
127 1, 1, 1
128 }));
129 // clang-format on
130
131 out = tf_data.zeros(/*sizes=*/{2, 3, 3});
132 op_scatter_add_out(self, 2, index, src, out);
133 // clang-format off
134 EXPECT_TENSOR_EQ(
135 out,
136 tf_data.make(
137 /*sizes=*/{2, 3, 3},
138 {
139 // [0, :, :]
140 2, 3, 1,
141 1, 5, 6,
142 8, 1, 9,
143
144 // [1, :, :]
145 1, 1, 1,
146 1, 1, 1,
147 1, 1, 1
148 }));
149 // clang-format on
150 }
151
152 // Invalid dimensions
153 template <ScalarType DATA_DTYPE>
test_scatter_add_out_invalid_dim()154 void test_scatter_add_out_invalid_dim() {
155 TensorFactory<ScalarType::Long> tf_index;
156 TensorFactory<DATA_DTYPE> tf_data;
157 const std::vector<int32_t> sizes = {3, 5};
158 // clang-format off
159 Tensor src = tf_data.make(/*sizes=*/{2, 5},
160 {
161 1, 2, 3, 4, 5,
162 6, 7, 8, 9, 10
163 });
164 Tensor index = tf_index.make(/*sizes=*/{2, 3},
165 {
166 0, 1, 2,
167 0, 1, 2
168 });
169 // clang-format on
170 Tensor self = tf_data.zeros(sizes);
171 Tensor out = tf_data.zeros(sizes);
172
173 // Invalid dim should die
174 ET_EXPECT_KERNEL_FAILURE(
175 context_, op_scatter_add_out(self, -3, index, src, out));
176 ET_EXPECT_KERNEL_FAILURE(
177 context_, op_scatter_add_out(self, 2, index, src, out));
178
179 // Self, index and src hsould have same number of dimensions
180 src = tf_data.zeros(/*sizes=*/{2, 2, 2});
181 ET_EXPECT_KERNEL_FAILURE(
182 context_, op_scatter_add_out(self, 0, index, src, out));
183
184 src = tf_data.zeros(/*sizes=*/{5, 5});
185 index = tf_index.zeros(/*sizes=*/{2, 2, 2});
186 ET_EXPECT_KERNEL_FAILURE(
187 context_, op_scatter_add_out(self, 0, index, src, out));
188
189 // Size of dimension of index should be smaller than the size of that
190 // dimension of src
191 index = tf_index.zeros(/*sizes=*/{4, 6});
192 ET_EXPECT_KERNEL_FAILURE(
193 context_, op_scatter_add_out(self, 0, index, src, out));
194
195 // Size of dimension of index should be smaller than the size of that
196 // dimension of self if dimension != dim
197 index = tf_index.zeros(/*sizes=*/{4, 5});
198 ET_EXPECT_KERNEL_FAILURE(
199 context_, op_scatter_add_out(self, 1, index, src, out));
200
201 // Index out of bound for self in dim
202 index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 3, 0, 1, 3});
203 ET_EXPECT_KERNEL_FAILURE(
204 context_, op_scatter_add_out(self, 0, index, src, out));
205 }
206
207 // Mismatched shape
208 template <ScalarType DATA_DTYPE>
test_scatter_add_out_mismatched_shape()209 void test_scatter_add_out_mismatched_shape() {
210 TensorFactory<ScalarType::Long> tf_index;
211 TensorFactory<DATA_DTYPE> tf_data;
212
213 // clang-format off
214 Tensor src = tf_data.make(/*sizes=*/{2, 5},
215 {
216 1, 2, 3, 4, 5,
217 6, 7, 8, 9, 10
218 });
219 Tensor index = tf_index.make(/*sizes=*/{2, 3},
220 {
221 0, 1, 2,
222 0, 1, 2
223 });
224 // clang-format on
225 Tensor self = tf_data.zeros(/*sizes=*/{3, 5});
226 Tensor out = tf_data.zeros(/*sizes=*/{2, 5});
227
228 // self and out should be of the same shape
229 ET_EXPECT_KERNEL_FAILURE(
230 context_, op_scatter_add_out(self, 0, index, src, out));
231 }
232
233 /* %python
234 import torch
235 torch.manual_seed(0)
236 input_shape = (2, 3, 4)
237 input = torch.randint(10, input_shape)
238 dim = 2
239 index = torch.randint(input.size(dim), input_shape)
240 src = torch.randint(10, input_shape)
241 expected = torch.scatter_add(input, dim, index, src)
242
243 scatter_add_template = f"""
244 {declare_tensor_factory("ScalarType::Int", "tf")}
245 {declare_tensor_factory("ScalarType::Long", "tf_index")}
246
247 {declare_tensor_make_t("input", "tf")}
248 {declare_tensor_make_t("index", "tf_index")}
249 {declare_tensor_make_t("src", "tf")}
250 {declare_tensor_make_t("expected", "tf")}
251 {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
252
253 op_scatter_add_out(input, $dim$, index, src, out);
254 EXPECT_TENSOR_EQ(out, expected);""" */
255
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)256 void test_dynamic_shape(
257 const std::vector<int32_t>& out_shape,
258 enum torch::executor::TensorShapeDynamism dynamism) {
259 /* %python
260 %rewrite(scatter_add_template) */
261
262 TensorFactory<ScalarType::Int> tf;
263 TensorFactory<ScalarType::Long> tf_index;
264
265 Tensor input = tf.make({2, 3, 4}, {4, 9, 3, 0, 3, 9, 7, 3, 7, 3, 1, 6,
266 6, 9, 8, 6, 6, 8, 4, 3, 6, 9, 1, 4});
267 Tensor index =
268 tf_index.make({2, 3, 4}, {0, 1, 1, 1, 1, 0, 1, 0, 3, 0, 3, 1,
269 2, 3, 3, 0, 2, 3, 0, 1, 3, 1, 3, 3});
270 Tensor src = tf.make({2, 3, 4}, {2, 1, 0, 9, 3, 1, 1, 0, 3, 6, 6, 7,
271 9, 6, 3, 4, 5, 0, 8, 2, 8, 2, 7, 5});
272 Tensor expected =
273 tf.make({2, 3, 4}, {6, 19, 3, 0, 4, 13, 7, 3, 13, 10, 1, 15,
274 10, 9, 17, 15, 14, 10, 9, 3, 6, 11, 1, 24});
275 Tensor out = tf.zeros(out_shape, dynamism);
276
277 op_scatter_add_out(input, 2, index, src, out);
278 EXPECT_TENSOR_EQ(out, expected);
279 }
280 };
281
// Runs the common scatter_add.out value checks once per real dtype.
TEST_F(OpScatterAddOutTest, AllValidInputOutputSupport) {
#define RUN_SCATTER_ADD_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_add_out<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_SCATTER_ADD_FOR_DTYPE);
#undef RUN_SCATTER_ADD_FOR_DTYPE
}
287
// Checks IEEE special-value propagation: every diagonal element of out gets
// 1 (from self) plus two special values, and each such sum (inf + NaN,
// -inf + inf, NaN + -inf) is NaN. All other elements keep self's value of 1.
TEST_F(OpScatterAddOutTest, InfinityAndNANTest) {
  TensorFactory<ScalarType::Float> tf_data;
  TensorFactory<ScalarType::Long> tf_index;
  const std::vector<int32_t> sizes = {3, 5};

  Tensor self = tf_data.ones(sizes);
  Tensor out = tf_data.zeros(sizes);
  Tensor index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 2, 0, 1, 2});

  // clang-format off
  Tensor src = tf_data.make(
    /*sizes=*/{2, 5},
    {
      INFINITY, -INFINITY, NAN, 2.33, 3.14,
      NAN, INFINITY, -INFINITY, 3.14, 2.33
    });
  const Tensor expected = tf_data.make(
    sizes,
    {
      NAN, 1, 1, 1, 1,
      1, NAN, 1, 1, 1,
      1, 1, NAN, 1, 1
    });
  // clang-format on

  // Valid input should give the expected output
  op_scatter_add_out(self, 0, index, src, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
318
// Runs the invalid-dimension / invalid-index failure checks for every real
// dtype.
TEST_F(OpScatterAddOutTest, InvalidDimensionsDies) {
#define RUN_INVALID_DIM_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_add_out_invalid_dim<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_INVALID_DIM_FOR_DTYPE);
#undef RUN_INVALID_DIM_FOR_DTYPE
}
325
// Runs the self/out shape-mismatch failure check for every real dtype.
// Skipped under ATen, whose kernel accepts the mismatched out shape.
TEST_F(OpScatterAddOutTest, MismatchedShapeDies) {
  const auto running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shape";
  }
#define RUN_MISMATCHED_SHAPE_FOR_DTYPE(CTYPE, DTYPE) \
  test_scatter_add_out_mismatched_shape<ScalarType::DTYPE>();
  ET_FORALL_REAL_TYPES(RUN_MISMATCHED_SHAPE_FOR_DTYPE);
#undef RUN_MISMATCHED_SHAPE_FOR_DTYPE
}
335
// Checks that scatter_add.out rejects:
//   1. an index tensor whose dtype is not Long,
//   2. src and self with different dtypes,
//   3. self and out with different dtypes.
// After case 1 the index is switched to a valid Long tensor, so cases 2 and 3
// can only fail because of the dtype mismatch they are meant to exercise.
TEST_F(OpScatterAddOutTest, MismatchedInputDtypesDies) {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;
  TensorFactory<ScalarType::Long> tf_long;
  const std::vector<int32_t> sizes = {3, 5};
  // clang-format off
  Tensor src = tf_char.make(/*sizes=*/{2, 5},
    {
      1, 2, 3, 4, 5,
      6, 7, 8, 9, 10
    });
  Tensor index = tf_byte.make(/*sizes=*/{2, 3},
    {
      0, 1, 2,
      0, 1, 2
    });
  // clang-format on
  Tensor self = tf_char.zeros(sizes);
  Tensor out = tf_char.zeros(sizes);

  // Types other than long for index should die
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_add_out(self, 0, index, src, out));

  // From here on the index is valid (Long, in range), so any failure below is
  // attributable only to the dtype mismatch under test.
  // clang-format off
  index = tf_long.make(/*sizes=*/{2, 3},
    {
      0, 1, 2,
      0, 1, 2
    });
  // clang-format on

  // Mismatched dtype of src (Byte) and self (Char) should die
  // clang-format off
  src = tf_byte.make(/*sizes=*/{2, 5},
    {
      1, 2, 3, 4, 5,
      6, 7, 8, 9, 10
    });
  // clang-format on
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_add_out(self, 0, index, src, out));

  // Mismatched dtype of self (Byte) and out (Char) should die; src (Byte)
  // now matches self.
  self = tf_byte.zeros(sizes);
  out = tf_char.zeros(sizes);
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_add_out(self, 0, index, src, out));
}
384
// Bounded-dynamic out whose upper bound already equals the {2, 3, 4} result.
TEST_F(OpScatterAddOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> out_shape = {2, 3, 4};
  test_dynamic_shape(
      out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
389
// Bounded-dynamic out allocated larger ({10, 10, 10}) than the {2, 3, 4}
// result; passing requires the kernel to resize out, so the test is skipped
// when output resizing is not supported.
TEST_F(OpScatterAddOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const auto can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> upper_bound = {10, 10, 10};
  test_dynamic_shape(
      upper_bound, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
397
// Unbounded-dynamic out starting at {1, 1, 1}, smaller than the {2, 3, 4}
// result; passing requires the kernel to grow out, so the test is skipped
// when output resizing is not supported.
TEST_F(OpScatterAddOutTest, DynamicShapeUnbound) {
  const auto can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
405