xref: /aosp_15_r20/external/executorch/kernels/test/op_scatter_add_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 #include <cmath>
18 
19 using namespace ::testing;
20 using exec_aten::Scalar;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24 
25 class OpScatterAddOutTest : public OperatorTest {
26  protected:
op_scatter_add_out(const Tensor & self,int64_t dim,const Tensor & index,const Tensor & src,Tensor & out)27   Tensor& op_scatter_add_out(
28       const Tensor& self,
29       int64_t dim,
30       const Tensor& index,
31       const Tensor& src,
32       Tensor& out) {
33     return torch::executor::aten::scatter_add_outf(
34         context_, self, dim, index, src, out);
35   }
36 
37   // Common testing for the operator
38   template <ScalarType DATA_DTYPE>
test_scatter_add_out()39   void test_scatter_add_out() {
40     TensorFactory<ScalarType::Long> tf_index;
41     TensorFactory<DATA_DTYPE> tf_data;
42     const std::vector<int32_t> sizes = {3, 5};
43     // clang-format off
44     Tensor src = tf_data.make(
45       /*sizes=*/{2, 5},
46       {
47         1, 2, 3, 4, 5,
48         6, 7, 8, 9, 10
49       });
50     // clang-format on
51     Tensor self = tf_data.zeros(sizes);
52     Tensor out = tf_data.zeros(sizes);
53     // clang-format off
54     Tensor index = tf_index.make(
55       /*sizes=*/{2, 3},
56       {
57         0, 1, 2,
58         0, 1, 2
59       });
60     // clang-format on
61 
62     // Valid input should give the expected output
63     op_scatter_add_out(self, 0, index, src, out);
64     // clang-format off
65     EXPECT_TENSOR_EQ(
66         out, tf_data.make(
67           sizes,
68           {
69             7, 0, 0,  0, 0,
70             0, 9, 0,  0, 0,
71             0, 0, 11, 0, 0
72           }));
73     // clang-format on
74 
75     // Valid input should give the expected output
76     op_scatter_add_out(self, 1, index, src, out);
77     // clang-format off
78     EXPECT_TENSOR_EQ(
79         out, tf_data.make(sizes,
80         {
81           1, 2, 3, 0, 0,
82           6, 7, 8, 0, 0,
83           0, 0, 0, 0, 0
84         }));
85 
86     src = tf_data.make(
87         /*sizes=*/{2, 3, 3},
88         {
89           // [0, :, :]
90           1,  2,  3,
91           4,  5,  6,
92           7,  8,  9,
93 
94           // [1, :, :]
95           10, 11, 12,
96           13, 14, 15,
97           16, 17, 18
98         });
99     // clang-format on
100     self = tf_data.ones(/*sizes=*/{2, 3, 3});
101     out = tf_data.zeros(/*sizes=*/{2, 3, 3});
102     // clang-format off
103     index = tf_index.make(
104       /*sizes=*/{1, 3, 2},
105       {
106         0, 1,
107         1, 2,
108         0, 2
109       });
110     // clang-format on
111 
112     op_scatter_add_out(self, 1, index, src, out);
113     // clang-format off
114     EXPECT_TENSOR_EQ(
115         out,
116         tf_data.make(
117             /*sizes=*/{2, 3, 3},
118             {
119               // [0, :, :]
120               9, 1,  1,
121               5, 3,  1,
122               1, 14, 1,
123 
124               // [1, :, :]
125               1, 1,  1,
126               1, 1,  1,
127               1, 1,  1
128             }));
129     // clang-format on
130 
131     out = tf_data.zeros(/*sizes=*/{2, 3, 3});
132     op_scatter_add_out(self, 2, index, src, out);
133     // clang-format off
134     EXPECT_TENSOR_EQ(
135         out,
136         tf_data.make(
137             /*sizes=*/{2, 3, 3},
138             {
139               // [0, :, :]
140               2, 3, 1,
141               1, 5, 6,
142               8, 1, 9,
143 
144               // [1, :, :]
145               1, 1, 1,
146               1, 1, 1,
147               1, 1, 1
148             }));
149     // clang-format on
150   }
151 
152   // Invalid dimensions
153   template <ScalarType DATA_DTYPE>
test_scatter_add_out_invalid_dim()154   void test_scatter_add_out_invalid_dim() {
155     TensorFactory<ScalarType::Long> tf_index;
156     TensorFactory<DATA_DTYPE> tf_data;
157     const std::vector<int32_t> sizes = {3, 5};
158     // clang-format off
159     Tensor src = tf_data.make(/*sizes=*/{2, 5},
160       {
161         1, 2, 3, 4, 5,
162         6, 7, 8, 9, 10
163       });
164     Tensor index = tf_index.make(/*sizes=*/{2, 3},
165       {
166         0, 1, 2,
167         0, 1, 2
168       });
169     // clang-format on
170     Tensor self = tf_data.zeros(sizes);
171     Tensor out = tf_data.zeros(sizes);
172 
173     // Invalid dim should die
174     ET_EXPECT_KERNEL_FAILURE(
175         context_, op_scatter_add_out(self, -3, index, src, out));
176     ET_EXPECT_KERNEL_FAILURE(
177         context_, op_scatter_add_out(self, 2, index, src, out));
178 
179     // Self, index and src hsould have same number of dimensions
180     src = tf_data.zeros(/*sizes=*/{2, 2, 2});
181     ET_EXPECT_KERNEL_FAILURE(
182         context_, op_scatter_add_out(self, 0, index, src, out));
183 
184     src = tf_data.zeros(/*sizes=*/{5, 5});
185     index = tf_index.zeros(/*sizes=*/{2, 2, 2});
186     ET_EXPECT_KERNEL_FAILURE(
187         context_, op_scatter_add_out(self, 0, index, src, out));
188 
189     // Size of dimension of index should be smaller than the size of that
190     // dimension of src
191     index = tf_index.zeros(/*sizes=*/{4, 6});
192     ET_EXPECT_KERNEL_FAILURE(
193         context_, op_scatter_add_out(self, 0, index, src, out));
194 
195     // Size of dimension of index should be smaller than the size of that
196     // dimension of self if dimension != dim
197     index = tf_index.zeros(/*sizes=*/{4, 5});
198     ET_EXPECT_KERNEL_FAILURE(
199         context_, op_scatter_add_out(self, 1, index, src, out));
200 
201     // Index out of bound for self in dim
202     index = tf_index.make(/*sizes=*/{2, 3}, {0, 1, 3, 0, 1, 3});
203     ET_EXPECT_KERNEL_FAILURE(
204         context_, op_scatter_add_out(self, 0, index, src, out));
205   }
206 
207   // Mismatched shape
208   template <ScalarType DATA_DTYPE>
test_scatter_add_out_mismatched_shape()209   void test_scatter_add_out_mismatched_shape() {
210     TensorFactory<ScalarType::Long> tf_index;
211     TensorFactory<DATA_DTYPE> tf_data;
212 
213     // clang-format off
214     Tensor src = tf_data.make(/*sizes=*/{2, 5},
215       {
216         1, 2, 3, 4, 5,
217         6, 7, 8, 9, 10
218       });
219     Tensor index = tf_index.make(/*sizes=*/{2, 3},
220       {
221         0, 1, 2,
222         0, 1, 2
223       });
224     // clang-format on
225     Tensor self = tf_data.zeros(/*sizes=*/{3, 5});
226     Tensor out = tf_data.zeros(/*sizes=*/{2, 5});
227 
228     // self and out should be of the same shape
229     ET_EXPECT_KERNEL_FAILURE(
230         context_, op_scatter_add_out(self, 0, index, src, out));
231   }
232 
233   /* %python
234   import torch
235   torch.manual_seed(0)
236   input_shape = (2, 3, 4)
237   input = torch.randint(10, input_shape)
238   dim = 2
239   index = torch.randint(input.size(dim), input_shape)
240   src = torch.randint(10, input_shape)
241   expected = torch.scatter_add(input, dim, index, src)
242 
243   scatter_add_template = f"""
244     {declare_tensor_factory("ScalarType::Int", "tf")}
245     {declare_tensor_factory("ScalarType::Long", "tf_index")}
246 
247     {declare_tensor_make_t("input", "tf")}
248     {declare_tensor_make_t("index", "tf_index")}
249     {declare_tensor_make_t("src", "tf")}
250     {declare_tensor_make_t("expected", "tf")}
251     {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
252 
253     op_scatter_add_out(input, $dim$, index, src, out);
254     EXPECT_TENSOR_EQ(out, expected);""" */
255 
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)256   void test_dynamic_shape(
257       const std::vector<int32_t>& out_shape,
258       enum torch::executor::TensorShapeDynamism dynamism) {
259     /* %python
260     %rewrite(scatter_add_template) */
261 
262     TensorFactory<ScalarType::Int> tf;
263     TensorFactory<ScalarType::Long> tf_index;
264 
265     Tensor input = tf.make({2, 3, 4}, {4, 9, 3, 0, 3, 9, 7, 3, 7, 3, 1, 6,
266                                        6, 9, 8, 6, 6, 8, 4, 3, 6, 9, 1, 4});
267     Tensor index =
268         tf_index.make({2, 3, 4}, {0, 1, 1, 1, 1, 0, 1, 0, 3, 0, 3, 1,
269                                   2, 3, 3, 0, 2, 3, 0, 1, 3, 1, 3, 3});
270     Tensor src = tf.make({2, 3, 4}, {2, 1, 0, 9, 3, 1, 1, 0, 3, 6, 6, 7,
271                                      9, 6, 3, 4, 5, 0, 8, 2, 8, 2, 7, 5});
272     Tensor expected =
273         tf.make({2, 3, 4}, {6,  19, 3,  0,  4,  13, 7, 3, 13, 10, 1, 15,
274                             10, 9,  17, 15, 14, 10, 9, 3, 6,  11, 1, 24});
275     Tensor out = tf.zeros(out_shape, dynamism);
276 
277     op_scatter_add_out(input, 2, index, src, out);
278     EXPECT_TENSOR_EQ(out, expected);
279   }
280 };
281 
TEST_F(OpScatterAddOutTest, AllValidInputOutputSupport) {
  // Instantiate the common value checks once for each dtype enumerated by
  // ET_FORALL_REAL_TYPES.
#define TEST_ENTRY(ctype, dtype) \
  test_scatter_add_out<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
287 
TEST_F(OpScatterAddOutTest, InfinityAndNANTest) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;
  const std::vector<int32_t> out_sizes = {3, 5};

  // Scattering along dim 0, out[0][0], out[1][1] and out[2][2] each receive
  // one value from every src row on top of self's 1:
  //   1 + inf + nan, 1 - inf + inf, 1 + nan - inf
  // and all three sums are NaN. Every other entry keeps self's value of 1.
  // clang-format off
  Tensor src = tf_float.make(
      /*sizes=*/{2, 5},
      {
        INFINITY, -INFINITY, NAN,       2.33, 3.14,
        NAN,      INFINITY,  -INFINITY, 3.14, 2.33
      });
  Tensor index = tf_long.make(
      /*sizes=*/{2, 3},
      {
        0, 1, 2,
        0, 1, 2
      });
  Tensor expected = tf_float.make(
      out_sizes,
      {
        NAN, 1,   1,   1, 1,
        1,   NAN, 1,   1, 1,
        1,   1,   NAN, 1, 1
      });
  // clang-format on
  Tensor self = tf_float.ones(out_sizes);
  Tensor out = tf_float.zeros(out_sizes);

  op_scatter_add_out(self, 0, index, src, out);
  EXPECT_TENSOR_CLOSE(out, expected);
}
318 
TEST_F(OpScatterAddOutTest, InvalidDimensionsDies) {
  // Run the argument-validation checks for each dtype enumerated by
  // ET_FORALL_REAL_TYPES.
#define TEST_ENTRY(ctype, dtype) \
  test_scatter_add_out_invalid_dim<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
325 
TEST_F(OpScatterAddOutTest, MismatchedShapeDies) {
  // The ATen kernel accepts a differently shaped `out`, so this check only
  // applies to non-ATen builds.
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shape";
  }
#define TEST_ENTRY(ctype, dtype) \
  test_scatter_add_out_mismatched_shape<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
335 
TEST_F(OpScatterAddOutTest, MismatchedInputDtypesDies) {
  TensorFactory<ScalarType::Byte> tf_byte;
  TensorFactory<ScalarType::Char> tf_char;
  TensorFactory<ScalarType::Long> tf_long;
  const std::vector<int32_t> sizes = {3, 5};
  // clang-format off
  Tensor src = tf_char.make(/*sizes=*/{2, 5},
    {
      1, 2, 3, 4, 5,
      6, 7, 8, 9, 10
    });
  Tensor index = tf_byte.make(/*sizes=*/{2, 3},
    {
      0, 1, 2,
      0, 1, 2
    });
  // clang-format on
  Tensor self = tf_char.zeros(sizes);
  Tensor out = tf_char.zeros(sizes);

  // Types other than long for index should die. self, src and out all share
  // the Char dtype here, so the Byte index is the only problem.
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_add_out(self, 0, index, src, out));

  // Use a valid Long index for the remaining cases so that each failure can
  // only be attributed to the dtype mismatch under test, not to the index
  // dtype.
  // clang-format off
  index = tf_long.make(/*sizes=*/{2, 3},
    {
      0, 1, 2,
      0, 1, 2
    });
  // clang-format on

  // Mismatched dtype of src and self should die: src is Byte while self and
  // out are Char.
  // clang-format off
  src = tf_byte.make(/*sizes=*/{2, 5},
    {
      1, 2, 3, 4, 5,
      6, 7, 8, 9, 10
    });
  // clang-format on
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_add_out(self, 0, index, src, out));

  // Mismatched dtype of self and out should die: self and src are Byte while
  // out is Char.
  self = tf_byte.zeros(sizes);
  out = tf_char.zeros(sizes);
  ET_EXPECT_KERNEL_FAILURE(
      context_, op_scatter_add_out(self, 0, index, src, out));
}
384 
TEST_F(OpScatterAddOutTest, DynamicShapeUpperBoundSameAsExpected) {
  // Bounded-dynamic `out` whose upper bound equals the {2, 3, 4} result.
  const std::vector<int32_t> out_shape = {2, 3, 4};
  test_dynamic_shape(
      out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
389 
TEST_F(OpScatterAddOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  // Bounded-dynamic `out` allocated at {10, 10, 10}; the result is expected
  // to have shape {2, 3, 4}.
  const std::vector<int32_t> upper_bound = {10, 10, 10};
  test_dynamic_shape(
      upper_bound, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
397 
TEST_F(OpScatterAddOutTest, DynamicShapeUnbound) {
  const bool can_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  // Unbounded-dynamic `out` starting at {1, 1, 1}; the result is expected to
  // have shape {2, 3, 4}.
  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
405