xref: /aosp_15_r20/external/executorch/kernels/test/op_index_put_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 #include <sys/types.h>
18 
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25 
26 using OptTensorArrayRef = ArrayRef<optional<Tensor>>;
27 
28 class OpIndexPutOutTest : public OperatorTest {
29  protected:
op_index_put_out(const Tensor & input,OptTensorArrayRef indices,const Tensor & values,const bool accumulate,Tensor & out)30   Tensor& op_index_put_out(
31       const Tensor& input,
32       OptTensorArrayRef indices,
33       const Tensor& values,
34       const bool accumulate,
35       Tensor& out) {
36 #ifdef USE_ATEN_LIB
37     c10::List<std::optional<at::Tensor>> indices_list(indices);
38     return torch::executor::aten::index_put_outf(
39         context_, input, indices_list, values, accumulate, out);
40 #else
41     return torch::executor::aten::index_put_outf(
42         context_, input, indices, values, accumulate, out);
43 #endif
44   }
45 
46   template <
47       exec_aten::ScalarType INPUT_DTYPE,
48       exec_aten::ScalarType INDICES_DTYPE>
test_dtype()49   void test_dtype() {
50     TensorFactory<INPUT_DTYPE> tf;
51     TensorFactory<INDICES_DTYPE> tfl;
52     TensorFactory<ScalarType::Bool> tfb;
53 
54     // clang-format off
55     Tensor x = tf.make(
56         {3, 2, 4},
57         {
58           // [0, :, :]
59           1, 1, 1, 1, // [0, 0, :]
60           0, 0, 0, 0, // [0, 1, :]
61 
62           // [1, :, :]
63           1, 1, 1, 1, // [1, 0, :]
64           0, 0, 0, 0, // [1, 1, :]
65 
66           // [2, :, :]
67           1, 1, 1, 1, // [2, 0, :]
68           0, 0, 0, 0, // [2, 1, :]
69         });
70     // clang-format on
71 
72     // First, index_put to make everything equal to 1
73 
74     // indices [0, 1, :], [1, 1, :], [2, 1, :]
75     optional<Tensor> indices[] = {
76         optional<Tensor>(tfl.make({1, 3}, {0, 1, 2})),
77         optional<Tensor>(tfl.make({1, 3}, {1, 1, 1})),
78     };
79     // bool representation of the same index list
80     optional<Tensor> indices_bool[] = {
81         optional<Tensor>(tfb.make({3}, {true, true, true})),
82         optional<Tensor>(tfb.make({2}, {false, true})),
83     };
84 
85     Tensor values = tf.ones({3, 4});
86 
87     std::vector<int32_t> out_size{3, 2, 4};
88 
89     Tensor out = tf.zeros(out_size);
90     Tensor ret =
91         op_index_put_out(x, indices, values, /*accumulate=*/false, out);
92 
93     EXPECT_TENSOR_EQ(ret, out);
94     EXPECT_TENSOR_EQ(ret, tf.ones(out_size));
95 
96     // Repeat the test with bool indices
97     Tensor out_with_bool = tf.zeros(out_size);
98     Tensor ret_with_bool = op_index_put_out(
99         x, indices_bool, values, /*accumulate=*/false, out_with_bool);
100 
101     EXPECT_TENSOR_EQ(ret_with_bool, out_with_bool);
102     EXPECT_TENSOR_EQ(ret_with_bool, tf.ones(out_size));
103 
104     // Then, index_put to make everything equal to 0
105 
106     // indices [0, 1, :], [1, 0, :], [2, 0, :]
107     optional<Tensor> indices_alt[] = {
108         optional<Tensor>(tfl.make({1, 3}, {0, 1, 2})),
109         optional<Tensor>(tfl.make({1, 3}, {0, 0, 0})),
110     };
111     // bool representation of the same index list
112     optional<Tensor> indices_alt_bool[] = {
113         optional<Tensor>(tfb.make({3}, {true, true, true})),
114         optional<Tensor>(tfb.make({2}, {true, false})),
115     };
116 
117     Tensor values_alt = tf.zeros({3, 4});
118 
119     Tensor out_alt = tf.ones(out_size);
120     Tensor ret_alt = op_index_put_out(
121         x, indices_alt, values_alt, /*accumulate=*/false, out_alt);
122 
123     EXPECT_TENSOR_EQ(ret_alt, out_alt);
124     EXPECT_TENSOR_EQ(ret_alt, tf.zeros(out_size));
125 
126     // Repeat the test with bool indices
127     Tensor out_alt_with_bool = tf.ones(out_size);
128     Tensor ret_alt_with_bool = op_index_put_out(
129         x,
130         indices_alt_bool,
131         values_alt,
132         /*accumulate=*/false,
133         out_alt_with_bool);
134 
135     EXPECT_TENSOR_EQ(ret_alt_with_bool, out_alt_with_bool);
136     EXPECT_TENSOR_EQ(ret_alt_with_bool, tf.zeros(out_size));
137   }
138 
139   /* %python
140   import torch
141   torch.manual_seed(0)
142   input = torch.rand(2, 3, 4)
143   indices = [torch.tensor([1]), torch.tensor([0]), torch.tensor([1, 2])]
144   values = torch.rand(2)
145   accumulate = False
146   expected = input.index_put(indices, values, accumulate=accumulate)
147 
148   index_put_template = f"""
149     {declare_tensor_factory("ScalarType::Float", "tf")}
150     {declare_tensor_factory("ScalarType::Long", "tf_indices")}
151 
152     {declare_tensor_make_t("input", "tf")}
153     {declare_optional_tensor_list_make_t("indices", "tf_indices")}
154     {declare_tensor_make_t("values", "tf")}
155     {declare_tensor_make_t("expected", "tf")}
156     {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
157 
158     op_index_put_out(input, indices, values, $accumulate$, out);
159     EXPECT_TENSOR_EQ(out, expected);"""
160   */
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)161   void test_dynamic_shape(
162       const std::vector<int32_t>& out_shape,
163       enum torch::executor::TensorShapeDynamism dynamism) {
164     /* %python
165     %rewrite(index_put_template) */
166 
167     TensorFactory<ScalarType::Float> tf;
168     TensorFactory<ScalarType::Long> tf_indices;
169 
170     Tensor input = tf.make(
171         {2, 3, 4},
172         {0.49625658988952637,  0.7682217955589294,  0.08847743272781372,
173          0.13203048706054688,  0.30742281675338745, 0.6340786814689636,
174          0.4900934100151062,   0.8964447379112244,  0.455627977848053,
175          0.6323062777519226,   0.3488934636116028,  0.40171730518341064,
176          0.022325754165649414, 0.16885894536972046, 0.2938884496688843,
177          0.518521785736084,    0.6976675987243652,  0.800011396408081,
178          0.16102945804595947,  0.28226858377456665, 0.6816085577011108,
179          0.9151939749717712,   0.39709991216659546, 0.8741558790206909});
180     optional<Tensor> indices[] = {
181         optional<Tensor>(tf_indices.make({1}, {1})),
182         optional<Tensor>(tf_indices.make({1}, {0})),
183         optional<Tensor>(tf_indices.make({2}, {1, 2}))};
184     Tensor values = tf.make({2}, {0.41940832138061523, 0.5529070496559143});
185     Tensor expected = tf.make(
186         {2, 3, 4},
187         {0.49625658988952637,  0.7682217955589294,  0.08847743272781372,
188          0.13203048706054688,  0.30742281675338745, 0.6340786814689636,
189          0.4900934100151062,   0.8964447379112244,  0.455627977848053,
190          0.6323062777519226,   0.3488934636116028,  0.40171730518341064,
191          0.022325754165649414, 0.41940832138061523, 0.5529070496559143,
192          0.518521785736084,    0.6976675987243652,  0.800011396408081,
193          0.16102945804595947,  0.28226858377456665, 0.6816085577011108,
194          0.9151939749717712,   0.39709991216659546, 0.8741558790206909});
195     Tensor out = tf.zeros(out_shape, dynamism);
196 
197     op_index_put_out(input, indices, values, false, out);
198     EXPECT_TENSOR_EQ(out, expected);
199   }
200 
201   // Run the test by putting values into the selected elements
run_test_cases(const Tensor & x,OptTensorArrayRef indices,const Tensor & values,const Tensor & expected,const Tensor & expected_accum)202   void run_test_cases(
203       const Tensor& x,
204       OptTensorArrayRef indices,
205       const Tensor& values,
206       const Tensor& expected,
207       const Tensor& expected_accum) {
208     // Generated out tensor sharing same size and dtype with expected tensor
209     TensorFactory<ScalarType::Double> tf;
210 
211     const std::vector<int32_t> out_size(
212         expected.sizes().begin(), expected.sizes().end());
213     Tensor out = tf.ones(out_size);
214 
215     Tensor ret =
216         op_index_put_out(x, indices, values, /*accumulate=*/false, out);
217     EXPECT_TENSOR_EQ(out, ret);
218     EXPECT_TENSOR_EQ(ret, expected);
219 
220     Tensor out_accum = tf.ones(out_size);
221     Tensor ret_accum =
222         op_index_put_out(x, indices, values, /*accumulate=*/true, out_accum);
223     EXPECT_TENSOR_EQ(out_accum, ret_accum);
224     EXPECT_TENSOR_EQ(ret_accum, expected_accum);
225   }
226 };
227 
228 //
229 // Correctness Tests
230 //
231 
TEST_F(OpIndexPutOutTest,IndexPutMask)232 TEST_F(OpIndexPutOutTest, IndexPutMask) {
233   TensorFactory<ScalarType::Double> tf;
234   TensorFactory<ScalarType::Bool> tfb;
235   // clang-format off
236   Tensor x = tf.make(
237       {2, 3, 4},
238       {
239           // [0, :, :]
240           1.,   2.,   3.,   4., // [0, 0, :]
241           5.,   6.,   7.,   8., // [0, 1, :]
242           9.,  10.,  11.,  12., // [0, 2, :]
243 
244           // [1, :, :]
245          -1.,  -2.,  -3.,  -4., // [1, 0, :]
246          -5.,  -6.,  -7.,  -8., // [1, 1, :]
247          -9., -10., -11., -12., // [1, 2, :]
248       });
249   // clang-format on
250 
251   // clang-format off
252   Tensor indices = tfb.make(
253       {2, 3, 4},
254       {
255          // [0, :, :]
256           true, false, false, false, // [0, 0, :]
257          false, false,  true, false, // [0, 1, :]
258          false, false, false, false, // [0, 2, :]
259 
260          // [1, :, :]
261          false,  true, false, false, // [1, 0, :]
262          false, false, false, false, // [1, 1, :]
263          false, false,  true, false, // [1, 2, :]
264       });
265   // clang-format on
266 
267   // clang-format off
268   Tensor values = tf.make(
269     {4},
270     {10., 20., 30., 40.}
271   );
272   // clang-format on
273 
274   // clang-format off
275   Tensor expected = tf.make(
276       {2, 3, 4},
277       {
278           // [0, :, :]
279          10.,   2.,   3.,   4., // [0, 0, :]
280           5.,   6.,  20.,   8., // [0, 1, :]
281           9.,  10.,  11.,  12., // [0, 2, :]
282 
283           // [1, :, :]
284          -1.,  30.,  -3.,  -4., // [1, 0, :]
285          -5.,  -6.,  -7.,  -8., // [1, 1, :]
286          -9., -10.,  40., -12., // [1, 2, :]
287       });
288   // clang-format on
289 
290   // clang-format off
291   Tensor expected_accum = tf.make(
292       {2, 3, 4},
293       {
294           // [0, :, :]
295          11.,   2.,   3.,   4., // [0, 0, :]
296           5.,   6.,  27.,   8., // [0, 1, :]
297           9.,  10.,  11.,  12., // [0, 2, :]
298 
299           // [1, :, :]
300          -1.,  28.,  -3.,  -4., // [1, 0, :]
301          -5.,  -6.,  -7.,  -8., // [1, 1, :]
302          -9., -10.,  29., -12., // [1, 2, :]
303       });
304   // clang-format on
305 
306   run_test_cases(x, {indices}, values, expected, expected_accum);
307 }
308 
TEST_F(OpIndexPutOutTest,IndexPutMaskBroadcast)309 TEST_F(OpIndexPutOutTest, IndexPutMaskBroadcast) {
310   TensorFactory<ScalarType::Double> tf;
311   TensorFactory<ScalarType::Bool> tfb;
312   // clang-format off
313   Tensor x = tf.make(
314       {2, 3, 4},
315       {
316           // [0, :, :]
317           1.,   2.,   3.,   4., // [0, 0, :]
318           5.,   6.,   7.,   8., // [0, 1, :]
319           9.,  10.,  11.,  12., // [0, 2, :]
320 
321           // [1, :, :]
322          -1.,  -2.,  -3.,  -4., // [1, 0, :]
323          -5.,  -6.,  -7.,  -8., // [1, 1, :]
324          -9., -10., -11., -12., // [1, 2, :]
325       });
326   // clang-format on
327 
328   // Try to select the input value at indices
329   // [1, 0, 1], [1, 0, 2]. This is expressed in various ways to test different
330   // indexing expressions.
331 
332   // clang-format off
333   Tensor indices = tfb.make(
334       {2, 3, 4},
335       {
336          // [0, :, :]
337           true, false, false, false, // [0, 0, :]
338          false, false,  true, false, // [0, 1, :]
339          false, false, false, false, // [0, 2, :]
340 
341          // [1, :, :]
342          false,  true, false, false, // [1, 0, :]
343          false, false, false, false, // [1, 1, :]
344          false, false,  true, false, // [1, 2, :]
345       });
346   // clang-format on
347 
348   // clang-format off
349   Tensor values = tf.make(
350     {1},
351     {10.}
352   );
353   // clang-format on
354 
355   // clang-format off
356   Tensor expected = tf.make(
357       {2, 3, 4},
358       {
359           // [0, :, :]
360          10.,   2.,   3.,   4., // [0, 0, :]
361           5.,   6.,  10.,   8., // [0, 1, :]
362           9.,  10.,  11.,  12., // [0, 2, :]
363 
364           // [1, :, :]
365          -1.,  10.,  -3.,  -4., // [1, 0, :]
366          -5.,  -6.,  -7.,  -8., // [1, 1, :]
367          -9., -10.,  10., -12., // [1, 2, :]
368       });
369   // clang-format on
370 
371   // clang-format off
372   Tensor expected_accum = tf.make(
373       {2, 3, 4},
374       {
375           // [0, :, :]
376          11.,   2.,   3.,   4., // [0, 0, :]
377           5.,   6.,  17.,   8., // [0, 1, :]
378           9.,  10.,  11.,  12., // [0, 2, :]
379 
380           // [1, :, :]
381          -1.,   8.,  -3.,  -4., // [1, 0, :]
382          -5.,  -6.,  -7.,  -8., // [1, 1, :]
383          -9., -10.,  -1., -12., // [1, 2, :]
384       });
385   // clang-format on
386 
387   run_test_cases(x, {indices}, values, expected, expected_accum);
388 }
389 
TEST_F(OpIndexPutOutTest,PutFrontDimAllIndexes)390 TEST_F(OpIndexPutOutTest, PutFrontDimAllIndexes) {
391   TensorFactory<ScalarType::Double> tf;
392   TensorFactory<ScalarType::Int> tfi;
393   TensorFactory<ScalarType::Long> tfl;
394   TensorFactory<ScalarType::Bool> tfb;
395   // clang-format off
396   Tensor x = tf.make(
397       {2, 3, 4},
398       {
399           // [0, :, :]
400           1.,   2.,   3.,   4., // [0, 0, :]
401           5.,   6.,   7.,   8., // [0, 1, :]
402           9.,  10.,  11.,  12., // [0, 2, :]
403 
404           // [1, :, :]
405          -1.,  -2.,  -3.,  -4., // [1, 0, :]
406          -5.,  -6.,  -7.,  -8., // [1, 1, :]
407          -9., -10., -11., -12., // [1, 2, :]
408       });
409   // clang-format on
410 
411   // Try to select the input value at indices
412   // [1, 0, 1], [1, 0, 2]. This is expressed in various ways to test different
413   // indexing expressions.
414 
415   optional<Tensor> indices_long[] = {
416       optional<Tensor>(tfl.make({1}, {1})),
417       optional<Tensor>(tfl.make({1}, {0})),
418       optional<Tensor>(tfl.make({2}, {1, 2}))};
419 
420   optional<Tensor> indices_int[] = {
421       optional<Tensor>(tfi.make({1}, {1})),
422       optional<Tensor>(tfi.make({1}, {0})),
423       optional<Tensor>(tfi.make({2}, {1, 2}))};
424 
425   optional<Tensor> indices_negative[] = {
426       optional<Tensor>(tfl.make({1}, {-1})),
427       optional<Tensor>(tfl.make({1}, {0})),
428       optional<Tensor>(tfl.make({2}, {-3, -2}))};
429 
430   optional<Tensor> indices_bool[] = {
431       optional<Tensor>(tfb.make({2}, {false, true})),
432       optional<Tensor>(tfb.make({3}, {true, false, false})),
433       optional<Tensor>(tfl.make({2}, {-3, -2}))};
434 
435   optional<Tensor> indices_mixed[] = {
436       optional<Tensor>(tfb.make({2}, {false, true})),
437       optional<Tensor>(tfl.make({1}, {0})),
438       optional<Tensor>(tfl.make({2}, {-3, -2}))};
439 
440   // clang-format off
441   Tensor values = tf.make(
442     {2},
443     {10., 20.}
444   );
445   // clang-format on
446 
447   // clang-format off
448   Tensor expected = tf.make(
449       {2, 3, 4},
450       {
451           // [0, :, :]
452           1.,   2.,   3.,   4., // [0, 0, :]
453           5.,   6.,   7.,   8., // [0, 1, :]
454           9.,  10.,  11.,  12., // [0, 2, :]
455 
456           // [1, :, :]
457          -1.,  10.,  20.,  -4., // [1, 0, :]
458          -5.,  -6.,  -7.,  -8., // [1, 1, :]
459          -9., -10., -11., -12., // [1, 2, :]
460       });
461   // clang-format on
462 
463   // clang-format off
464   Tensor expected_accum = tf.make(
465       {2, 3, 4},
466       {
467           // [0, :, :]
468           1.,   2.,   3.,   4., // [0, 0, :]
469           5.,   6.,   7.,   8., // [0, 1, :]
470           9.,  10.,  11.,  12., // [0, 2, :]
471 
472           // [1, :, :]
473          -1.,   8.,  17.,  -4., // [1, 0, :]
474          -5.,  -6.,  -7.,  -8., // [1, 1, :]
475          -9., -10., -11., -12., // [1, 2, :]
476       });
477   // clang-format on
478 
479   run_test_cases(x, indices_long, values, expected, expected_accum);
480   run_test_cases(x, indices_int, values, expected, expected_accum);
481   run_test_cases(x, indices_negative, values, expected, expected_accum);
482   run_test_cases(x, indices_bool, values, expected, expected_accum);
483   run_test_cases(x, indices_mixed, values, expected, expected_accum);
484 }
485 
TEST_F(OpIndexPutOutTest,PutTwoValuesAtSameIndex)486 TEST_F(OpIndexPutOutTest, PutTwoValuesAtSameIndex) {
487   TensorFactory<ScalarType::Double> tf;
488   TensorFactory<ScalarType::Long> tfl;
489   // clang-format off
490   Tensor x = tf.make(
491       {2, 3, 4},
492       {
493           // [0, :, :]
494           1.,   2.,   3.,   4., // [0, 0, :]
495           5.,   6.,   7.,   8., // [0, 1, :]
496           9.,  10.,  11.,  12., // [0, 2, :]
497 
498           // [1, :, :]
499          -1.,  -2.,  -3.,  -4., // [1, 0, :]
500          -5.,  -6.,  -7.,  -8., // [1, 1, :]
501          -9., -10., -11., -12., // [1, 2, :]
502       });
503   // clang-format on
504 
505   // Try to select the value at the same index
506   optional<Tensor> indices[] = {
507       optional<Tensor>(tfl.make({1, 2}, {0, 0})),
508       optional<Tensor>(tfl.make({1, 2}, {1, 1})),
509       optional<Tensor>(tfl.make({1, 2}, {2, 2}))};
510 
511   // clang-format off
512   Tensor values = tf.make(
513     {1},
514     {10.,}
515   );
516   // clang-format on
517 
518   // clang-format off
519   Tensor expected = tf.make(
520       {2, 3, 4},
521       {
522           // [0, :, :]
523           1.,   2.,   3.,   4., // [0, 0, :]
524           5.,   6.,  10.,   8., // [0, 1, :]
525           9.,  10.,  11.,  12., // [0, 2, :]
526 
527           // [1, :, :]
528          -1.,  -2.,  -3.,  -4., // [1, 0, :]
529          -5.,  -6.,  -7.,  -8., // [1, 1, :]
530          -9., -10., -11., -12., // [1, 2, :]
531       });
532   // clang-format on
533 
534   // clang-format off
535   Tensor expected_accum = tf.make(
536       {2, 3, 4},
537       {
538           // [0, :, :]
539           1.,   2.,   3.,   4., // [0, 0, :]
540           5.,   6.,  27.,   8., // [0, 1, :]
541           9.,  10.,  11.,  12., // [0, 2, :]
542 
543           // [1, :, :]
544          -1.,  -2.,  -3.,  -4., // [1, 0, :]
545          -5.,  -6.,  -7.,  -8., // [1, 1, :]
546          -9., -10., -11., -12., // [1, 2, :]
547       });
548   // clang-format on
549 
550   run_test_cases(x, /*indices=*/indices, values, expected, expected_accum);
551 }
552 
TEST_F(OpIndexPutOutTest,IndicesFewerThanInputDimSupported)553 TEST_F(OpIndexPutOutTest, IndicesFewerThanInputDimSupported) {
554   TensorFactory<ScalarType::Double> tf;
555   TensorFactory<ScalarType::Int> tfi;
556   TensorFactory<ScalarType::Long> tfl;
557   TensorFactory<ScalarType::Bool> tfb;
558   // clang-format off
559   Tensor x = tf.make(
560       {2, 3, 4},
561       {
562           // [0, :, :]
563           1.,   2.,   3.,   4., // [0, 0, :]
564           5.,   6.,   7.,   8., // [0, 1, :]
565           9.,  10.,  11.,  12., // [0, 2, :]
566 
567           // [1, :, :]
568          -1.,  -2.,  -3.,  -4., // [1, 0, :]
569          -5.,  -6.,  -7.,  -8., // [1, 1, :]
570          -9., -10., -11., -12., // [1, 2, :]
571       });
572   // clang-format on
573 
574   // Try to select the input value at indices
575   // [1, 0, :], [1, 1, :]. This is expressed in various ways to test different
576   // indexing expressions.
577 
578   optional<Tensor> indices_long[] = {
579       optional<Tensor>(tfl.make({1}, {1})),
580       optional<Tensor>(tfl.make({2}, {0, 1}))};
581 
582   optional<Tensor> indices_mixed[] = {
583       optional<Tensor>(tfi.make({1}, {-1})),
584       optional<Tensor>(tfb.make({3}, {true, true, false}))};
585 
586   // clang-format off
587   Tensor values = tf.make(
588     {2, 4},
589     {
590        10.,  20.,  30.,  40.,
591       -10., -20., -30., -40.,
592     }
593   );
594   // clang-format on
595 
596   // clang-format off
597   Tensor expected = tf.make(
598       {2, 3, 4},
599       {
600           // [0, :, :]
601           1.,   2.,   3.,   4., // [0, 0, :]
602           5.,   6.,   7.,   8., // [0, 1, :]
603           9.,  10.,  11.,  12., // [0, 2, :]
604 
605           // [1, :, :]
606           10.,  20.,  30.,  40., // [1, 0, :]
607          -10., -20., -30., -40., // [1, 1, :]
608           -9., -10., -11., -12., // [1, 2, :]
609       });
610   // clang-format on
611 
612   // clang-format off
613   Tensor expected_accum = tf.make(
614       {2, 3, 4},
615       {
616           // [0, :, :]
617           1.,   2.,   3.,   4., // [0, 0, :]
618           5.,   6.,   7.,   8., // [0, 1, :]
619           9.,  10.,  11.,  12., // [0, 2, :]
620 
621           // [1, :, :]
622            9.,  18.,  27.,  36., // [1, 0, :]
623          -15., -26., -37., -48., // [1, 1, :]
624           -9., -10., -11., -12., // [1, 2, :]
625       });
626   // clang-format on
627 
628   run_test_cases(x, indices_long, values, expected, expected_accum);
629   run_test_cases(x, indices_mixed, values, expected, expected_accum);
630 }
631 
TEST_F(OpIndexPutOutTest,IndicesFewerThanInputDimSupportedSameValue)632 TEST_F(OpIndexPutOutTest, IndicesFewerThanInputDimSupportedSameValue) {
633   TensorFactory<ScalarType::Double> tf;
634   TensorFactory<ScalarType::Long> tfl;
635   // clang-format off
636   Tensor x = tf.make(
637       {2, 3, 4},
638       {
639           // [0, :, :]
640           1.,   2.,   3.,   4., // [0, 0, :]
641           5.,   6.,   7.,   8., // [0, 1, :]
642           9.,  10.,  11.,  12., // [0, 2, :]
643 
644           // [1, :, :]
645          -1.,  -2.,  -3.,  -4., // [1, 0, :]
646          -5.,  -6.,  -7.,  -8., // [1, 1, :]
647          -9., -10., -11., -12., // [1, 2, :]
648       });
649   // clang-format on
650 
651   // Try to select the input value at indices
652   // [1, 0, :], [1, 1, :]
653   optional<Tensor> indices[] = {
654       optional<Tensor>(tfl.make({1}, {1})),
655       optional<Tensor>(tfl.make({2}, {0, 1}))};
656 
657   // clang-format off
658   Tensor values = tf.make(
659     {1},
660     {10.}
661   );
662   // clang-format on
663 
664   // clang-format off
665   Tensor expected = tf.make(
666       {2, 3, 4},
667       {
668           // [0, :, :]
669           1.,   2.,   3.,   4., // [0, 0, :]
670           5.,   6.,   7.,   8., // [0, 1, :]
671           9.,  10.,  11.,  12., // [0, 2, :]
672 
673           // [1, :, :]
674          10.,  10.,  10.,  10., // [1, 0, :]
675          10.,  10.,  10.,  10., // [1, 1, :]
676          -9., -10., -11., -12., // [1, 2, :]
677       });
678   // clang-format on
679 
680   // clang-format off
681   Tensor expected_accum = tf.make(
682       {2, 3, 4},
683       {
684           // [0, :, :]
685           1.,   2.,   3.,   4., // [0, 0, :]
686           5.,   6.,   7.,   8., // [0, 1, :]
687           9.,  10.,  11.,  12., // [0, 2, :]
688 
689           // [1, :, :]
690           9.,   8.,   7.,   6., // [1, 0, :]
691           5.,   4.,   3.,   2., // [1, 1, :]
692          -9., -10., -11., -12., // [1, 2, :]
693       });
694   // clang-format on
695 
696   run_test_cases(x, /*indices=*/indices, values, expected, expected_accum);
697 }
698 
699 //
700 // Test that all dtypes are supported
701 //
702 
703 /**
704  * Generic test for integral index lists
705  */
TEST_F(OpIndexPutOutTest,AllDtypesSupportedForInput)706 TEST_F(OpIndexPutOutTest, AllDtypesSupportedForInput) {
707 #define TEST_ENTRY(ctype, dtype) \
708   test_dtype<ScalarType::dtype, ScalarType::Long>();
709 
710   ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
711 
712 #undef TEST_ENTRY
713 }
714 
// Index tensors may be int64 or int32. The input dtype is held at Float so
// that only the index dtype varies.
TEST_F(OpIndexPutOutTest, AllDtypesSupportedForIndicesList) {
  // 64-bit integer indices.
  test_dtype<ScalarType::Float, ScalarType::Long>();
  // 32-bit integer indices.
  test_dtype<ScalarType::Float, ScalarType::Int>();
}
719 
720 //
721 // Death Tests
722 //
723 
// Dimension 0 of the input has size 1, so index 5 is out of range and the
// kernel must fail.
TEST_F(OpIndexPutOutTest, IndexOutOfBoundDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> out_of_range[] = {
      optional<Tensor>(tf_long.make({1}, {5}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, out_of_range, values, /*accumulate=*/false, out),
      "");
}
745 
// Dimension 0 of the input has size 1, so the negative index -5 is out of
// range and the kernel must fail.
TEST_F(OpIndexPutOutTest, NegativeIndexOutOfBoundDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> out_of_range[] = {
      optional<Tensor>(tf_long.make({1}, {-5}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, out_of_range, values, /*accumulate=*/false, out),
      "");
}
767 
// A length-3 boolean mask is applied to a dimension of size 1, so the mask is
// longer than the dimension it indexes and the kernel must fail.
TEST_F(OpIndexPutOutTest, TooManyBooleanIndexCountDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> oversized_mask[] = {
      optional<Tensor>(tf_bool.make({3}, {true, true, false}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, oversized_mask, values, /*accumulate=*/false, out),
      "");
}
789 
// A length-1 boolean mask is applied to a dimension of size 4, so the mask is
// shorter than the dimension it indexes and the kernel must fail.
TEST_F(OpIndexPutOutTest, TooFewBooleanIndexCountDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({4});
  Tensor out = tf_float.zeros({4});
  optional<Tensor> undersized_mask[] = {
      optional<Tensor>(tf_bool.make({1}, {true}))};
  Tensor values = tf_float.make({1}, {10.});

  // Note: the ATen kernel reports this failure by throwing an exception
  // rather than by dying.
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, undersized_mask, values, /*accumulate=*/false, out),
      "");
}
812 
// A {3, 3} boolean mask does not match the {4, 4} input it is applied to, so
// the kernel must fail.
TEST_F(OpIndexPutOutTest, MismatchedIndexMaskDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({4, 4});
  Tensor out = tf_float.zeros({4, 4});
  optional<Tensor> wrong_shape_mask[] = {
      optional<Tensor>(tf_bool.ones({3, 3}))};
  Tensor values = tf_float.make({1}, {10.});

  // Note: the ATen kernel reports this failure by throwing an exception
  // rather than by dying.
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, wrong_shape_mask, values, /*accumulate=*/false, out),
      "");
}
835 
// The input and values are Float but the output is Double, so the kernel must
// reject the output dtype.
TEST_F(OpIndexPutOutTest, MismatchedOutputDtypesDies) {
  TensorFactory<ScalarType::Float> float_factory;
  TensorFactory<ScalarType::Double> double_factory;
  TensorFactory<ScalarType::Long> long_factory;

  Tensor input = float_factory.zeros({1, 2, 2});

  // Same shape as the input, but with a mismatched dtype.
  Tensor double_out = double_factory.ones({1, 2, 2});
  optional<Tensor> first_row[] = {
      optional<Tensor>(long_factory.make({1}, {0}))};
  Tensor values = float_factory.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, first_row, values, /*accumulate=*/false, double_out),
      "");
}
860 
// The input and output are Float but the values are Double, so the kernel
// must reject the values dtype.
TEST_F(OpIndexPutOutTest, MismatchedValuesDtypesDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor x = tf_float.zeros({1, 2, 2});

  // The output matches the input in both size and dtype.
  Tensor out = tf_float.ones({1, 2, 2});
  Tensor index = tf_long.make({1}, {0});

  // The values are broadcast-compatible in size, but have a mismatched dtype.
  // clang-format off
  Tensor values = tf_double.make(
    {1},
    {10.}
  );
  // clang-format on

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          x, /*indices=*/{index}, values, /*accumulate=*/false, out),
      "");
}
885 
// Index [1] selects a region whose trailing dimensions are {4, 7, 5}. A
// values tensor of shape {1, 2} cannot be broadcast to that region (2 vs 5 in
// the last dimension), so the kernel must fail.
TEST_F(OpIndexPutOutTest, ValuesSizeMismatchDimDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({2, 4, 7, 5});
  optional<Tensor> second_slice[] = {
      optional<Tensor>(tf_long.make({1}, {1}))};
  Tensor out = tf_float.ones({2, 4, 7, 5});
  Tensor mis_sized_values = tf_float.make({1, 2}, {10., 10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input,
          second_slice,
          mis_sized_values,
          /*accumulate=*/false,
          out),
      "");
}
908 
// Index tensors must be integer or boolean; Float index tensors must be
// rejected.
TEST_F(OpIndexPutOutTest, InvalidIndicesDtypeDies) {
  // A single Float factory produces both the data tensors and the (invalid)
  // floating-point index tensors.
  TensorFactory<ScalarType::Float> tf_float;

  Tensor input = tf_float.zeros({2, 4, 7, 5});
  optional<Tensor> float_indices[] = {
      optional<Tensor>(tf_float.make({3}, {1, 1, 1})),
      optional<Tensor>(tf_float.make({2}, {1, 2}))};

  Tensor out = tf_float.ones({2, 4, 7, 5});
  Tensor values = tf_float.make({1}, {10});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, float_indices, values, /*accumulate=*/false, out),
      "");
}
934 
// Index tensors of shapes {3} and {2} cannot be broadcast against each other,
// so the kernel must fail.
TEST_F(OpIndexPutOutTest, InvalidIndicesShapesDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({2, 4, 7, 5});
  optional<Tensor> incompatible_indices[] = {
      optional<Tensor>(tf_long.make({3}, {1, 1, 1})),
      optional<Tensor>(tf_long.make({2}, {1, 2}))};

  Tensor out = tf_float.ones({2, 4, 7, 5});
  Tensor values = tf_float.make({1, 2}, {10., 10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          input, incompatible_indices, values, /*accumulate=*/false, out),
      "");
}
960 
// Index tensors of shapes {2, 2} and {1, 2} broadcast to {2, 2}. They select
// the (row, col) pairs (1, 3), (1, 0), (1, 3), (1, 0), so each of the two
// target locations is hit twice. Both overwrite and accumulate modes are
// checked.
TEST_F(OpIndexPutOutTest, NonLinearIndices) {
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor x = tf.zeros({4, 4});
  // clang-format off
  optional<Tensor> indices[] = {
      optional<Tensor>(tfl.make({2, 2}, {1, 1, 1, 1,})),
      optional<Tensor>(tfl.make({1, 2}, {3, 0,}))};
  // clang-format on

  // A single value, broadcast to every selected location.
  // clang-format off
  Tensor values = tf.make(
    {1},
    {10.}
  );
  // clang-format on

  // Overwrite mode: both target locations are simply set to 10.
  Tensor out = tf.ones({4, 4});
  Tensor expected =
      tf.make({4, 4}, {0, 0, 0, 0, 10, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0});

  Tensor ret = op_index_put_out(x, indices, values, /*accumulate=*/false, out);

  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_EQ(ret, expected);

  // Accumulate mode: each location receives the value once per occurrence in
  // the broadcast index set, i.e. 0 + 10 + 10 = 20. This matches the
  // duplicate-index accumulation checked in PutTwoValuesAtSameIndex.
  Tensor out_accum = tf.ones({4, 4});
  Tensor expected_accum =
      tf.make({4, 4}, {0, 0, 0, 0, 20, 0, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0});

  Tensor ret_accum =
      op_index_put_out(x, indices, values, /*accumulate=*/true, out_accum);

  EXPECT_TENSOR_EQ(ret_accum, out_accum);
  EXPECT_TENSOR_EQ(ret_accum, expected_accum);
}
989 
990 //
991 // Dynamic Shape Tests
992 //
993 
// The dynamically bounded output already has the expected {2, 3, 4} shape.
TEST_F(OpIndexPutOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> exact_shape = {2, 3, 4};
  test_dynamic_shape(
      exact_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
998 
// The output starts with a {10, 10, 10} upper bound and has to be shrunk to
// the expected {2, 3, 4}. This requires output-resize support in the kernel
// build under test.
TEST_F(OpIndexPutOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool resize_supported =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!resize_supported) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> oversized_shape = {10, 10, 10};
  test_dynamic_shape(
      oversized_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
1006 
// The output starts as an unbounded {1, 1, 1} tensor and has to grow to the
// expected {2, 3, 4}. This requires output-resize support in the kernel build
// under test.
TEST_F(OpIndexPutOutTest, DynamicShapeUnbound) {
  const bool resize_supported =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!resize_supported) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }
  const std::vector<int32_t> undersized_shape = {1, 1, 1};
  test_dynamic_shape(
      undersized_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
1014