xref: /aosp_15_r20/external/executorch/kernels/test/op_index_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 
17 #include <gtest/gtest.h>
18 #include <sys/types.h>
19 
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::optional;
23 using exec_aten::ScalarType;
24 using exec_aten::Tensor;
25 using torch::executor::testing::TensorFactory;
26 
27 using OptTensorArrayRef = ArrayRef<optional<Tensor>>;
28 
29 class OpIndexTensorOutTest : public OperatorTest {
30  protected:
op_index_tensor_out(const Tensor & input,OptTensorArrayRef indices,Tensor & out)31   Tensor& op_index_tensor_out(
32       const Tensor& input,
33       OptTensorArrayRef indices,
34       Tensor& out) {
35 #ifdef USE_ATEN_LIB
36     c10::List<std::optional<at::Tensor>> indices_list(indices);
37     return torch::executor::aten::index_outf(
38         context_, input, indices_list, out);
39 #else
40     return torch::executor::aten::index_outf(context_, input, indices, out);
41 #endif
42   }
43 
44   template <
45       exec_aten::ScalarType INPUT_DTYPE,
46       exec_aten::ScalarType INDEX_DTYPE,
47       exec_aten::ScalarType OUTPUT_DTYPE>
test_dtype()48   void test_dtype() {
49     TensorFactory<INPUT_DTYPE> tf;
50     TensorFactory<INDEX_DTYPE> tfl;
51     TensorFactory<OUTPUT_DTYPE> tfo;
52     TensorFactory<ScalarType::Bool> tfb;
53 
54     // clang-format off
55     Tensor x = tf.make(
56         {3, 2, 4},
57         {
58           // all ones below are from x,
59           // and all zeros are from y.
60           // [0, :, :]
61           1, 1, 1, 1, // [0, 0, :]
62           0, 0, 0, 0, // [0, 1, :]
63 
64           // [1, :, :]
65           1, 1, 1, 1, // [1, 0, :]
66           0, 0, 0, 0, // [1, 1, :]
67 
68           // [2, :, :]
69           1, 1, 1, 1, // [2, 0, :]
70           0, 0, 0, 0, // [2, 1, :]
71         });
72     // clang-format on
73 
74     // indices [0, 1, 2], [1, 0, 3], expressed two different ways
75     optional<Tensor> indices[] = {
76         optional<Tensor>(tfl.make({2}, {0, 1})),
77         optional<Tensor>(tfl.make({2}, {1, 0})),
78         optional<Tensor>(tfl.make({2}, {2, 3}))};
79 
80     optional<Tensor> indices_mixed[] = {
81         optional<Tensor>(tfl.make({2}, {0, 1})),
82         optional<Tensor>(tfb.make({2}, {false, true})),
83         optional<Tensor>(tfl.make({2}, {2, 3}))};
84 
85     std::vector<int32_t> out_size{2};
86 
87     Tensor out_0 = tfo.zeros(out_size);
88     Tensor ret_0 = op_index_tensor_out(x, /*indices=*/indices, out_0);
89 
90     EXPECT_TENSOR_EQ(ret_0, out_0);
91     EXPECT_TENSOR_EQ(ret_0, tfo.make(out_size, {0, 1}));
92 
93     // Repeat the test with alternative indices representation
94 
95     Tensor out_0_with_mixed = tfo.zeros(out_size);
96     Tensor ret_0_with_mixed =
97         op_index_tensor_out(x, /*indices=*/indices, out_0_with_mixed);
98 
99     EXPECT_TENSOR_EQ(ret_0_with_mixed, out_0_with_mixed);
100     EXPECT_TENSOR_EQ(ret_0_with_mixed, tfo.make(out_size, {0, 1}));
101   }
102 
103   /**
104    * Generic test for integral index lists
105    */
test_dtype_enumerate_in_types()106   void test_dtype_enumerate_in_types() {
107 #define TEST_ENTRY(ctype, dtype) \
108   test_dtype<ScalarType::dtype, ScalarType::Long, ScalarType::dtype>();
109 
110     ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
111 
112 #undef TEST_ENTRY
113   }
114 
115   // Run the test by selecting elements in input
run_test_cases(const Tensor & x,OptTensorArrayRef indices,const Tensor & expected)116   void run_test_cases(
117       const Tensor& x,
118       OptTensorArrayRef indices,
119       const Tensor& expected) {
120     // Generated out tensor sharing same size and dtype with expected tensor
121     TensorFactory<ScalarType::Double> tf;
122 
123     const std::vector<int32_t> out_size(
124         expected.sizes().begin(), expected.sizes().end());
125     Tensor out = tf.ones(out_size);
126 
127     Tensor ret = op_index_tensor_out(x, indices, out);
128     EXPECT_TENSOR_EQ(out, ret);
129     EXPECT_TENSOR_EQ(ret, expected);
130   }
131 };
132 
133 //
134 // Correctness Tests
135 //
136 
// A boolean mask shaped like the input picks every element whose mask entry
// is true, in row-major order, producing a 1-D result.
TEST_F(OpIndexTensorOutTest, IndexMask) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Double> tf_double;

  // clang-format off
  const Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
         -1.,  -2.,  -3.,  -4., // [1, 0, :]
         -5.,  -6.,  -7.,  -8., // [1, 1, :]
         -9., -10., -11., -12., // [1, 2, :]
      });

  const Tensor mask = tf_bool.make(
      {2, 3, 4},
      {
         // [0, :, :]
          true, false, false, false, // [0, 0, :]
         false, false,  true, false, // [0, 1, :]
         false, false, false, false, // [0, 2, :]

         // [1, :, :]
         false,  true, false, false, // [1, 0, :]
         false, false, false, false, // [1, 1, :]
         false, false,  true, false, // [1, 2, :]
      });
  // clang-format on

  // True entries are at [0, 0, 0], [0, 1, 2], [1, 0, 1] and [1, 2, 2].
  const Tensor want = tf_double.make({4}, {1., 7., -2., -11.});

  optional<Tensor> mask_indices[] = {optional<Tensor>(mask)};
  run_test_cases(input, mask_indices, want);
}
181 
TEST_F(OpIndexTensorOutTest,SelectFrontDimAllIndexes)182 TEST_F(OpIndexTensorOutTest, SelectFrontDimAllIndexes) {
183   TensorFactory<ScalarType::Double> tf;
184   TensorFactory<ScalarType::Int> tfi;
185   TensorFactory<ScalarType::Long> tfl;
186   TensorFactory<ScalarType::Bool> tfb;
187   // clang-format off
188   Tensor x = tf.make(
189       {2, 3, 4},
190       {
191           // [0, :, :]
192           1.,   2.,   3.,   4., // [0, 0, :]
193           5.,   6.,   7.,   8., // [0, 1, :]
194           9.,  10.,  11.,  12., // [0, 2, :]
195 
196           // [1, :, :]
197          -1.,  -2.,  -3.,  -4., // [1, 0, :]
198          -5.,  -6.,  -7.,  -8., // [1, 1, :]
199          -9., -10., -11., -12., // [1, 2, :]
200       });
201   // clang-format on
202 
203   // Try to select the input value at indices
204   // [1, 0, 1], [1, 0, 2]. This is expressed in various ways to test different
205   // indexing expressions.
206   optional<Tensor> indices[] = {
207       optional<Tensor>(tfl.make({1}, {1})),
208       optional<Tensor>(tfl.make({1}, {0})),
209       optional<Tensor>(tfl.make({2}, {1, 2}))};
210 
211   optional<Tensor> indices_int[] = {
212       optional<Tensor>(tfi.make({1}, {1})),
213       optional<Tensor>(tfi.make({1}, {0})),
214       optional<Tensor>(tfi.make({2}, {1, 2}))};
215 
216   optional<Tensor> indices_negative[] = {
217       optional<Tensor>(tfl.make({1}, {-1})),
218       optional<Tensor>(tfl.make({1}, {0})),
219       optional<Tensor>(tfl.make({2}, {-3, -2}))};
220 
221   optional<Tensor> indices_bool[] = {
222       optional<Tensor>(tfb.make({2}, {false, true})),
223       optional<Tensor>(tfb.make({3}, {true, false, false})),
224       optional<Tensor>(tfl.make({2}, {-3, -2}))};
225 
226   optional<Tensor> indices_mixed[] = {
227       optional<Tensor>(tfb.make({2}, {false, true})),
228       optional<Tensor>(tfl.make({1}, {0})),
229       optional<Tensor>(tfl.make({2}, {-3, -2}))};
230 
231   std::vector<int32_t> out_size{2};
232 
233   // clang-format off
234   Tensor expected = tf.make(
235     out_size,
236     {-2., -3.,}
237   );
238   // clang-format on
239 
240   run_test_cases(x, /*indices=*/indices, expected);
241   run_test_cases(x, /*indices=*/indices_int, expected);
242   run_test_cases(x, /*indices=*/indices_negative, expected);
243   run_test_cases(x, /*indices=*/indices_bool, expected);
244   run_test_cases(x, /*indices=*/indices_mixed, expected);
245 }
246 
// Index tensors of shape {1, 2} that repeat one coordinate. The result keeps
// the index shape and holds x[0, 1, 2] twice.
TEST_F(OpIndexTensorOutTest, SelectTwoValuesAtSameIndex) {
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Double> tf_double;

  // clang-format off
  const Tensor x = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
         -1.,  -2.,  -3.,  -4., // [1, 0, :]
         -5.,  -6.,  -7.,  -8., // [1, 1, :]
         -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  optional<Tensor> repeated[] = {
      optional<Tensor>(tf_long.make({1, 2}, {0, 0})),
      optional<Tensor>(tf_long.make({1, 2}, {1, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {2, 2}))};

  // In ATen the size is (1, 2) as well.
  const Tensor expected = tf_double.make({1, 2}, {7., 7.});

  run_test_cases(x, repeated, expected);
}
283 
// Only the leading two dimensions are indexed; the trailing dimension is kept
// whole. Both index lists select the rows x[1, 0, :] and x[1, 1, :].
TEST_F(OpIndexTensorOutTest, IndicesFewerThanInputDimSupported) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Double> tf_double;

  // clang-format off
  const Tensor x = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
          1.,   2.,   3.,   4., // [0, 0, :]
          5.,   6.,   7.,   8., // [0, 1, :]
          9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
         -1.,  -2.,  -3.,  -4., // [1, 0, :]
         -5.,  -6.,  -7.,  -8., // [1, 1, :]
         -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Long indices for dims 0 and 1.
  optional<Tensor> by_long[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({2}, {0, 1}))};

  // A negative Int index for dim 0 and a boolean mask for dim 1.
  optional<Tensor> by_mixed[] = {
      optional<Tensor>(tf_int.make({1}, {-1})),
      optional<Tensor>(tf_bool.make({3}, {true, true, false}))};

  // clang-format off
  const Tensor expected = tf_double.make(
    {2, 4},
    {
      -1.,  -2.,  -3.,  -4.,
      -5.,  -6.,  -7.,  -8.,
    }
  );
  // clang-format on

  run_test_cases(x, by_long, expected);
  run_test_cases(x, by_mixed, expected);
}
332 
TEST_F(OpIndexTensorOutTest,IndicesWithNullTensorsSupported)333 TEST_F(OpIndexTensorOutTest, IndicesWithNullTensorsSupported) {
334   TensorFactory<ScalarType::Double> tf;
335   TensorFactory<ScalarType::Long> tfl;
336   // clang-format off
337   Tensor x = tf.make(
338       {2, 3, 4},
339       {
340           // [0, :, :]
341           1.,   2.,   3.,   4., // [0, 0, :]
342           5.,   6.,   7.,   8., // [0, 1, :]
343           9.,  10.,  11.,  12., // [0, 2, :]
344 
345           // [1, :, :]
346          -1.,  -2.,  -3.,  -4., // [1, 0, :]
347          -5.,  -6.,  -7.,  -8., // [1, 1, :]
348          -9., -10., -11., -12., // [1, 2, :]
349       });
350   // clang-format on
351 
352   optional<Tensor> indices0[] = {
353       optional<Tensor>(),
354       optional<Tensor>(tfl.make({1}, {1})),
355       optional<Tensor>(tfl.make({2}, {0, 1}))};
356 
357   // clang-format off
358   Tensor expected0 = tf.make(
359     {2, 2},
360     {
361        5.,   6.,
362       -5.,  -6.,
363     }
364   );
365   // clang-format on
366 
367   run_test_cases(x, /*indices=*/indices0, expected0);
368 
369   optional<Tensor> indices1[] = {
370       optional<Tensor>(tfl.make({1}, {1})),
371       optional<Tensor>(),
372       optional<Tensor>(tfl.make({2}, {0, 1}))};
373 
374   // clang-format off
375   Tensor expected1 = tf.make(
376     {2, 3},
377     {
378       -1.,  -5.,  -9.,
379       -2.,  -6., -10.,
380     }
381   );
382   // clang-format on
383 
384   run_test_cases(x, /*indices=*/indices1, expected1);
385 
386   optional<Tensor> indices2[] = {
387       optional<Tensor>(tfl.make({1}, {1})),
388       optional<Tensor>(tfl.make({2}, {0, 1})),
389       optional<Tensor>()};
390 
391   // clang-format off
392   Tensor expected2 = tf.make(
393     {2, 4},
394     {
395       -1.,  -2.,  -3.,  -4.,
396       -5.,  -6.,  -7.,  -8.,
397     }
398   );
399   // clang-format on
400 
401   run_test_cases(x, /*indices=*/indices2, expected2);
402 }
403 
// Index lists made only of nulls. One or two nulls on this 2-D input return
// the input unchanged; three nulls (one more than the input's rank) are
// expected to make the kernel fail.
TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel test fails";
  }
  TensorFactory<ScalarType::Double> tf_double;

  const Tensor x = tf_double.make({2, 3}, {1., 2., 3., 4., 5., 6.});

  optional<Tensor> one_null[] = {optional<Tensor>()};
  run_test_cases(x, one_null, x);

  optional<Tensor> two_nulls[] = {optional<Tensor>(), optional<Tensor>()};
  run_test_cases(x, two_nulls, x);

  optional<Tensor> three_nulls[] = {
      optional<Tensor>(), optional<Tensor>(), optional<Tensor>()};
  Tensor out = tf_double.ones({2, 3});
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, three_nulls, out), "");
}
423 
// An empty indices list leaves every dimension unindexed, so `out` should
// become an exact copy of the input.
TEST_F(OpIndexTensorOutTest, EmptyIndicesSupported) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel test fails";
  }
  TensorFactory<ScalarType::Float> tf;

  // The input is a small non-empty tensor; what is empty here is the list of
  // index tensors passed below.
  Tensor x = tf.make({2}, {1., 2.});

  Tensor out = tf.zeros({2});

  Tensor& ret = op_index_tensor_out(x, /*indices=*/{}, out);
  // The returned tensor must be `out`, and `out` must now equal `x`.
  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_EQ(out, x);
}
440 
441 //
442 // Test that all dtypes are supported
443 //
444 
// Sweeps the input/output dtype over every ET_FORALL_REALHBF16_TYPES entry,
// with Long index tensors throughout.
TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForInput) {
  this->test_dtype_enumerate_in_types();
}
448 
// Keeps a Double input/output and varies only the integral index dtype.
TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForIndex) {
  // 64-bit indices.
  this->test_dtype<ScalarType::Double, ScalarType::Long, ScalarType::Double>();
  // 32-bit indices.
  this->test_dtype<ScalarType::Double, ScalarType::Int, ScalarType::Double>();
}
453 
454 //
455 // Death Tests
456 //
457 
// Index 5 on a dimension of size 1 must be rejected.
TEST_F(OpIndexTensorOutTest, IndexOutOfBoundDies) {
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor x = tf_int.ones({1, 1, 1});
  Tensor out = tf_int.zeros({1, 1, 1});
  optional<Tensor> too_large[] = {optional<Tensor>(tf_long.make({1}, {5}))};

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, too_large, out), "");
}
469 
// Index -5 on a dimension of size 1 wraps past the start and must be rejected.
TEST_F(OpIndexTensorOutTest, NegativeIndexOutOfBoundDies) {
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor x = tf_int.ones({1, 1, 1});
  Tensor out = tf_int.zeros({1, 1, 1});
  optional<Tensor> too_negative[] = {
      optional<Tensor>(tf_long.make({1}, {-5}))};

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, too_negative, out), "");
}
481 
// A boolean mask with three entries for a dimension of size 1 must be
// rejected.
TEST_F(OpIndexTensorOutTest, TooManyBooleanIndexCountDies) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor x = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> long_mask[] = {
      optional<Tensor>(tf_bool.make({3}, {true, false, false}))};

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, long_mask, out), "");
}
493 
// A one-entry boolean mask for a dimension of size 4 must be rejected.
TEST_F(OpIndexTensorOutTest, TooFewBooleanIndexCountDies) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor x = tf_float.ones({4});
  Tensor out = tf_float.zeros({1});
  optional<Tensor> short_mask[] = {
      optional<Tensor>(tf_bool.make({1}, {true}))};

  // ATen kernel will throw exception instead of death
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, short_mask, out), "");
}
506 
// A {3, 3} boolean mask applied to a {4, 4} input does not match its shape
// and must be rejected.
TEST_F(OpIndexTensorOutTest, MismatchedIndexMaskDies) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor x = tf_float.ones({4, 4});
  Tensor out = tf_float.zeros({9});
  optional<Tensor> mismatched_mask[] = {optional<Tensor>(tf_bool.ones({3, 3}))};

  // ATen kernel will throw exception instead of death
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, mismatched_mask, out), "");
}
519 
// Selecting one slice along dim 0 of a {2, 4, 7, 5} input yields {1, 4, 7, 5};
// an out tensor of shape {2, 4} must be rejected.
TEST_F(OpIndexTensorOutTest, MismatchedOutputDimDies) {
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor x = tf_int.zeros({2, 4, 7, 5});
  optional<Tensor> front_index[] = {optional<Tensor>(tf_long.make({1}, {3}))};

  // Should be {1, 4, 7, 5}
  Tensor out = tf_int.zeros({2, 4});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, front_index, out), "");
}
533 
// A Float index tensor is not a valid index dtype and must be rejected.
TEST_F(OpIndexTensorOutTest, InvalidIndicesDtypeDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor x = tf_int.zeros({2, 4, 7, 5});
  optional<Tensor> float_index[] = {
      optional<Tensor>(tf_float.make({1}, {3}))};

  Tensor out = tf_int.zeros({1, 4, 7, 5});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, float_index, out), "");
}
546 
// Index tensors of lengths 3 and 2 cannot be broadcast together and must be
// rejected.
TEST_F(OpIndexTensorOutTest, InvalidIndicesShapesDies) {
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor x = tf_float.zeros({2, 4, 7, 5});

  // clang-format off
  optional<Tensor> unbroadcastable[] = {
      optional<Tensor>(tf_long.make({3}, {1, 1, 1,})),
      optional<Tensor>(tf_long.make({2}, {1, 2}))};
  // clang-format on

  Tensor out = tf_float.ones({3, 7, 5});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, unbroadcastable, out), "");
}
563 
// Index tensors of shapes {2, 2} and {1, 2} with a {4} out tensor are
// expected to make the portable kernel fail.
// NOTE(review): this test does not show which check rejects the call (index
// shape handling vs. the out-shape check); confirm against the kernel.
TEST_F(OpIndexTensorOutTest, InvalidIndicesShapeDies2) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    // Give the skip a reason (it was previously empty), matching the other
    // ATen skips in this file.
    GTEST_SKIP() << "ATen kernel test fails";
  }
  TensorFactory<ScalarType::Float> tf;
  TensorFactory<ScalarType::Long> tfl;

  Tensor x = tf.zeros({4, 4});
  // clang-format off
  optional<Tensor> indices[] = {
      optional<Tensor>(tfl.make({2, 2}, {1, 1, 1, 1,})),
      optional<Tensor>(tfl.make({1, 2}, {3, 0,}))};

  Tensor out = tf.ones({4});
  // clang-format on

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(x, indices, out), "");
}
583 
584 //
585 // Dynamic Shape Tests
586 //
587 
588 // Test whether resize works when out is having larger size
TEST_F(OpIndexTensorOutTest,UpperBoundOutTensor)589 TEST_F(OpIndexTensorOutTest, UpperBoundOutTensor) {
590   TensorFactory<ScalarType::Double> tf;
591   TensorFactory<ScalarType::Long> tfl;
592   // clang-format off
593   Tensor x = tf.make(
594       {2, 3, 4},
595       {
596           // [0, :, :]
597           1.,   2.,   3.,   4., // [0, 0, :]
598           5.,   6.,   7.,   8., // [0, 1, :]
599           9.,  10.,  11.,  12., // [0, 2, :]
600 
601           // [1, :, :]
602          -1.,  -2.,  -3.,  -4., // [1, 0, :]
603          -5.,  -6.,  -7.,  -8., // [1, 1, :]
604          -9., -10., -11., -12., // [1, 2, :]
605       });
606   // clang-format on
607 
608   // Try to select the tensor from the input
609   // indices [0, 2, 2], [1, 1, 2]
610   optional<Tensor> indices[] = {
611       optional<Tensor>(tfl.make({1, 2}, {0, 1})),
612       optional<Tensor>(tfl.make({1, 2}, {2, 1})),
613       optional<Tensor>(tfl.make({1, 2}, {2, 2}))};
614 
615   Tensor out =
616       tf.zeros({5, 5}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
617   // clang-format off
618   Tensor expected = tf.make(
619     {1, 2},
620     {
621           11.,  -7.
622     }
623   );
624   // clang-format on
625 
626   Tensor ret = op_index_tensor_out(x, indices, out);
627   EXPECT_TENSOR_EQ(out, ret);
628   EXPECT_TENSOR_EQ(ret, expected);
629 }
630