1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16
17 #include <gtest/gtest.h>
18 #include <sys/types.h>
19
20 using namespace ::testing;
21 using exec_aten::ArrayRef;
22 using exec_aten::optional;
23 using exec_aten::ScalarType;
24 using exec_aten::Tensor;
25 using torch::executor::testing::TensorFactory;
26
27 using OptTensorArrayRef = ArrayRef<optional<Tensor>>;
28
29 class OpIndexTensorOutTest : public OperatorTest {
30 protected:
op_index_tensor_out(const Tensor & input,OptTensorArrayRef indices,Tensor & out)31 Tensor& op_index_tensor_out(
32 const Tensor& input,
33 OptTensorArrayRef indices,
34 Tensor& out) {
35 #ifdef USE_ATEN_LIB
36 c10::List<std::optional<at::Tensor>> indices_list(indices);
37 return torch::executor::aten::index_outf(
38 context_, input, indices_list, out);
39 #else
40 return torch::executor::aten::index_outf(context_, input, indices, out);
41 #endif
42 }
43
44 template <
45 exec_aten::ScalarType INPUT_DTYPE,
46 exec_aten::ScalarType INDEX_DTYPE,
47 exec_aten::ScalarType OUTPUT_DTYPE>
test_dtype()48 void test_dtype() {
49 TensorFactory<INPUT_DTYPE> tf;
50 TensorFactory<INDEX_DTYPE> tfl;
51 TensorFactory<OUTPUT_DTYPE> tfo;
52 TensorFactory<ScalarType::Bool> tfb;
53
54 // clang-format off
55 Tensor x = tf.make(
56 {3, 2, 4},
57 {
58 // all ones below are from x,
59 // and all zeros are from y.
60 // [0, :, :]
61 1, 1, 1, 1, // [0, 0, :]
62 0, 0, 0, 0, // [0, 1, :]
63
64 // [1, :, :]
65 1, 1, 1, 1, // [1, 0, :]
66 0, 0, 0, 0, // [1, 1, :]
67
68 // [2, :, :]
69 1, 1, 1, 1, // [2, 0, :]
70 0, 0, 0, 0, // [2, 1, :]
71 });
72 // clang-format on
73
74 // indices [0, 1, 2], [1, 0, 3], expressed two different ways
75 optional<Tensor> indices[] = {
76 optional<Tensor>(tfl.make({2}, {0, 1})),
77 optional<Tensor>(tfl.make({2}, {1, 0})),
78 optional<Tensor>(tfl.make({2}, {2, 3}))};
79
80 optional<Tensor> indices_mixed[] = {
81 optional<Tensor>(tfl.make({2}, {0, 1})),
82 optional<Tensor>(tfb.make({2}, {false, true})),
83 optional<Tensor>(tfl.make({2}, {2, 3}))};
84
85 std::vector<int32_t> out_size{2};
86
87 Tensor out_0 = tfo.zeros(out_size);
88 Tensor ret_0 = op_index_tensor_out(x, /*indices=*/indices, out_0);
89
90 EXPECT_TENSOR_EQ(ret_0, out_0);
91 EXPECT_TENSOR_EQ(ret_0, tfo.make(out_size, {0, 1}));
92
93 // Repeat the test with alternative indices representation
94
95 Tensor out_0_with_mixed = tfo.zeros(out_size);
96 Tensor ret_0_with_mixed =
97 op_index_tensor_out(x, /*indices=*/indices, out_0_with_mixed);
98
99 EXPECT_TENSOR_EQ(ret_0_with_mixed, out_0_with_mixed);
100 EXPECT_TENSOR_EQ(ret_0_with_mixed, tfo.make(out_size, {0, 1}));
101 }
102
103 /**
104 * Generic test for integral index lists
105 */
test_dtype_enumerate_in_types()106 void test_dtype_enumerate_in_types() {
107 #define TEST_ENTRY(ctype, dtype) \
108 test_dtype<ScalarType::dtype, ScalarType::Long, ScalarType::dtype>();
109
110 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
111
112 #undef TEST_ENTRY
113 }
114
115 // Run the test by selecting elements in input
run_test_cases(const Tensor & x,OptTensorArrayRef indices,const Tensor & expected)116 void run_test_cases(
117 const Tensor& x,
118 OptTensorArrayRef indices,
119 const Tensor& expected) {
120 // Generated out tensor sharing same size and dtype with expected tensor
121 TensorFactory<ScalarType::Double> tf;
122
123 const std::vector<int32_t> out_size(
124 expected.sizes().begin(), expected.sizes().end());
125 Tensor out = tf.ones(out_size);
126
127 Tensor ret = op_index_tensor_out(x, indices, out);
128 EXPECT_TENSOR_EQ(out, ret);
129 EXPECT_TENSOR_EQ(ret, expected);
130 }
131 };
132
133 //
134 // Correctness Tests
135 //
136
// A single boolean mask shaped like the input yields the values at the
// `true` positions as a 1-D tensor, in row-major order.
TEST_F(OpIndexTensorOutTest, IndexMask) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Bool> tf_bool;

  // clang-format off
  Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
           1.,   2.,   3.,   4., // [0, 0, :]
           5.,   6.,   7.,   8., // [0, 1, :]
           9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });

  // Marks [0, 0, 0], [0, 1, 2], [1, 0, 1] and [1, 2, 2].
  Tensor mask = tf_bool.make(
      {2, 3, 4},
      {
          // [0, :, :]
          true,  false, false, false, // [0, 0, :]
          false, false, true,  false, // [0, 1, :]
          false, false, false, false, // [0, 2, :]

          // [1, :, :]
          false, true,  false, false, // [1, 0, :]
          false, false, false, false, // [1, 1, :]
          false, false, true,  false, // [1, 2, :]
      });

  Tensor want = tf_double.make({4}, {1., 7., -2., -11.});
  // clang-format on

  run_test_cases(input, {mask}, want);
}
181
// Indexes every dimension of the input and checks that Long, Int, negative,
// boolean and mixed index tensors all pick the same two elements.
TEST_F(OpIndexTensorOutTest, SelectFrontDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Bool> tf_bool;

  // clang-format off
  Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
           1.,   2.,   3.,   4., // [0, 0, :]
           5.,   6.,   7.,   8., // [0, 1, :]
           9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Every index set below targets the elements at [1, 0, 1] and [1, 0, 2];
  // only the way the positions are spelled differs.

  // Plain Long indices.
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({1}, {0})),
      optional<Tensor>(tf_long.make({2}, {1, 2}))};

  // Int indices.
  optional<Tensor> indices_int[] = {
      optional<Tensor>(tf_int.make({1}, {1})),
      optional<Tensor>(tf_int.make({1}, {0})),
      optional<Tensor>(tf_int.make({2}, {1, 2}))};

  // Negative indices counting back from the end of each dimension.
  optional<Tensor> indices_negative[] = {
      optional<Tensor>(tf_long.make({1}, {-1})),
      optional<Tensor>(tf_long.make({1}, {0})),
      optional<Tensor>(tf_long.make({2}, {-3, -2}))};

  // Boolean masks for the first two dimensions.
  optional<Tensor> indices_bool[] = {
      optional<Tensor>(tf_bool.make({2}, {false, true})),
      optional<Tensor>(tf_bool.make({3}, {true, false, false})),
      optional<Tensor>(tf_long.make({2}, {-3, -2}))};

  // A boolean mask combined with integral indices.
  optional<Tensor> indices_mixed[] = {
      optional<Tensor>(tf_bool.make({2}, {false, true})),
      optional<Tensor>(tf_long.make({1}, {0})),
      optional<Tensor>(tf_long.make({2}, {-3, -2}))};

  std::vector<int32_t> result_size{2};
  Tensor want = tf_double.make(result_size, {-2., -3.});

  run_test_cases(input, /*indices=*/indices, want);
  run_test_cases(input, /*indices=*/indices_int, want);
  run_test_cases(input, /*indices=*/indices_negative, want);
  run_test_cases(input, /*indices=*/indices_bool, want);
  run_test_cases(input, /*indices=*/indices_mixed, want);
}
246
// The same position may be requested more than once; each request produces
// its own output element.
TEST_F(OpIndexTensorOutTest, SelectTwoValuesAtSameIndex) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // clang-format off
  Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
           1.,   2.,   3.,   4., // [0, 0, :]
           5.,   6.,   7.,   8., // [0, 1, :]
           9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Both columns of the {1, 2} index tensors point at [0, 1, 2].
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({1, 2}, {0, 0})),
      optional<Tensor>(tf_long.make({1, 2}, {1, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {2, 2}))};

  // The result keeps the {1, 2} shape of the index tensors (as in ATen).
  std::vector<int32_t> result_size{1, 2};
  Tensor want = tf_double.make(result_size, {7., 7.});

  run_test_cases(input, /*indices=*/indices, want);
}
283
// Supplying fewer index tensors than the input has dimensions keeps the
// trailing, un-indexed dimension whole.
TEST_F(OpIndexTensorOutTest, IndicesFewerThanInputDimSupported) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Bool> tf_bool;

  // clang-format off
  Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
           1.,   2.,   3.,   4., // [0, 0, :]
           5.,   6.,   7.,   8., // [0, 1, :]
           9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Both index sets pick the rows [1, 0, :] and [1, 1, :].

  // Integral indices only.
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({2}, {0, 1}))};

  // A negative Int index combined with a boolean mask.
  optional<Tensor> indices_mixed[] = {
      optional<Tensor>(tf_int.make({1}, {-1})),
      optional<Tensor>(tf_bool.make({3}, {true, true, false}))};

  std::vector<int32_t> result_size{2, 4};

  // clang-format off
  Tensor want = tf_double.make(
      result_size,
      {
          -1., -2., -3., -4.,
          -5., -6., -7., -8.,
      });
  // clang-format on

  run_test_cases(input, /*indices=*/indices, want);
  run_test_cases(input, /*indices=*/indices_mixed, want);
}
332
// An empty optional in the index list leaves that dimension un-indexed.
// The empty slot is tried in the first, middle and last position.
TEST_F(OpIndexTensorOutTest, IndicesWithNullTensorsSupported) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // clang-format off
  Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
           1.,   2.,   3.,   4., // [0, 0, :]
           5.,   6.,   7.,   8., // [0, 1, :]
           9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Empty slot first: dim 0 is kept, dims 1 and 2 are indexed.
  optional<Tensor> null_first[] = {
      optional<Tensor>(),
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({2}, {0, 1}))};

  // clang-format off
  Tensor want_null_first = tf_double.make(
      {2, 2},
      {
           5.,  6.,
          -5., -6.,
      });
  // clang-format on

  run_test_cases(input, /*indices=*/null_first, want_null_first);

  // Empty slot in the middle: dims 0 and 2 are indexed, dim 1 is kept.
  optional<Tensor> null_middle[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(),
      optional<Tensor>(tf_long.make({2}, {0, 1}))};

  // clang-format off
  Tensor want_null_middle = tf_double.make(
      {2, 3},
      {
          -1., -5.,  -9.,
          -2., -6., -10.,
      });
  // clang-format on

  run_test_cases(input, /*indices=*/null_middle, want_null_middle);

  // Empty slot last: dims 0 and 1 are indexed, dim 2 is kept.
  optional<Tensor> null_last[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({2}, {0, 1})),
      optional<Tensor>()};

  // clang-format off
  Tensor want_null_last = tf_double.make(
      {2, 4},
      {
          -1., -2., -3., -4.,
          -5., -6., -7., -8.,
      });
  // clang-format on

  run_test_cases(input, /*indices=*/null_last, want_null_last);
}
403
// An index list made only of empty optionals returns the input unchanged,
// as long as the list is no longer than the input's rank. Skipped when
// running against the ATen kernel.
TEST_F(OpIndexTensorOutTest, IndicesWithOnlyNullTensorsSupported) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel test fails";
  }
  TensorFactory<ScalarType::Double> tf_double;

  Tensor input = tf_double.make({2, 3}, {1., 2., 3., 4., 5., 6.});

  // One empty slot for a 2-D input.
  optional<Tensor> one_null[] = {optional<Tensor>()};
  run_test_cases(input, one_null, input);

  // Two empty slots, one per dimension.
  optional<Tensor> two_nulls[] = {optional<Tensor>(), optional<Tensor>()};
  run_test_cases(input, two_nulls, input);

  // Three slots exceed the two dimensions of the input and must be rejected.
  optional<Tensor> three_nulls[] = {
      optional<Tensor>(), optional<Tensor>(), optional<Tensor>()};
  Tensor result = tf_double.ones({2, 3});
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, three_nulls, result), "");
}
423
// An empty index list copies the input into `out` unchanged. Skipped when
// running against the ATen kernel.
TEST_F(OpIndexTensorOutTest, EmptyIndicesSupported) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel test fails";
  }
  TensorFactory<ScalarType::Float> tf_float;

  // A non-empty input; only the index list is empty.
  Tensor input = tf_float.make({2}, {1., 2.});
  Tensor result = tf_float.zeros({2});

  op_index_tensor_out(input, /*indices=*/{}, result);

  // The call must succeed and leave `result` equal to the input.
  EXPECT_TENSOR_EQ(result, input);
}
440
441 //
442 // Test that all dtypes are supported
443 //
444
// Runs the shared dtype check once for every input dtype enumerated by the
// fixture helper, always with Long indices.
TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForInput) {
  test_dtype_enumerate_in_types();
}
448
// Runs the shared dtype check with Double data for each integral index
// dtype covered here.
TEST_F(OpIndexTensorOutTest, AllDtypesSupportedForIndex) {
  // Long index tensors.
  test_dtype<ScalarType::Double, ScalarType::Long, ScalarType::Double>();
  // Int index tensors.
  test_dtype<ScalarType::Double, ScalarType::Int, ScalarType::Double>();
}
453
454 //
455 // Death Tests
456 //
457
// A positive index past the end of the indexed dimension must be rejected.
TEST_F(OpIndexTensorOutTest, IndexOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_int.ones({1, 1, 1});
  Tensor result = tf_int.zeros({1, 1, 1});
  // Dim 0 has size 1, so 5 is out of range.
  Tensor bad_index = tf_long.make({1}, {5});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_tensor_out(input, /*indices=*/{bad_index}, result),
      "");
}
469
// A negative index that reaches before the start of the indexed dimension
// must be rejected.
TEST_F(OpIndexTensorOutTest, NegativeIndexOutOfBoundDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_int.ones({1, 1, 1});
  Tensor result = tf_int.zeros({1, 1, 1});
  // Dim 0 has size 1, so -5 is out of range.
  Tensor bad_index = tf_long.make({1}, {-5});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_tensor_out(input, /*indices=*/{bad_index}, result),
      "");
}
481
// A boolean mask longer than the dimension it indexes must be rejected.
TEST_F(OpIndexTensorOutTest, TooManyBooleanIndexCountDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor result = tf_float.zeros({1, 1, 1});
  // Three mask entries for a dimension of size 1.
  Tensor mask = tf_bool.make({3}, {true, false, false});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, /*indices=*/{mask}, result), "");
}
493
// A boolean mask shorter than the dimension it indexes must be rejected.
TEST_F(OpIndexTensorOutTest, TooFewBooleanIndexCountDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({4});
  Tensor result = tf_float.zeros({1});
  // One mask entry for a dimension of size 4.
  Tensor mask = tf_bool.make({1}, {true});

  // The ATen kernel reports this by throwing an exception rather than dying.
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, /*indices=*/{mask}, result), "");
}
506
// A multi-dimensional boolean mask whose shape differs from the input's
// must be rejected.
TEST_F(OpIndexTensorOutTest, MismatchedIndexMaskDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({4, 4});
  Tensor result = tf_float.zeros({9});
  // A {3, 3} mask applied to a {4, 4} input.
  Tensor mask = tf_bool.ones({3, 3});

  // The ATen kernel reports this by throwing an exception rather than dying.
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, /*indices=*/{mask}, result), "");
}
519
// An out tensor whose shape cannot hold the result must be rejected.
TEST_F(OpIndexTensorOutTest, MismatchedOutputDimDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_int.zeros({2, 4, 7, 5});
  // NOTE(review): the value 3 is also out of range for dim 0 (size 2), so the
  // failure may come from bounds checking rather than from the out shape —
  // confirm which check this test is meant to exercise.
  Tensor index = tf_long.make({1}, {3});

  // The result shape should be {1, 4, 7, 5}; {2, 4} does not match it.
  Tensor result = tf_int.zeros({2, 4});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, /*indices=*/{index}, result), "");
}
533
// A floating-point index tensor must be rejected.
TEST_F(OpIndexTensorOutTest, InvalidIndicesDtypeDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Float> tf_float;

  Tensor input = tf_int.zeros({2, 4, 7, 5});
  // Float is the invalid part here.
  // NOTE(review): the value 3 is also out of range for dim 0 (size 2) —
  // confirm the dtype check is what triggers the failure.
  Tensor float_index = tf_float.make({1}, {3});

  Tensor result = tf_int.zeros({1, 4, 7, 5});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_tensor_out(input, /*indices=*/{float_index}, result),
      "");
}
546
// Index tensors whose shapes cannot be combined ({3} with {2}) must be
// rejected.
TEST_F(OpIndexTensorOutTest, InvalidIndicesShapesDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({2, 4, 7, 5});

  // Lengths 3 and 2 do not match and neither is 1.
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({3}, {1, 1, 1})),
      optional<Tensor>(tf_long.make({2}, {1, 2}))};

  Tensor result = tf_float.ones({3, 7, 5});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, indices, result), "");
}
563
// Expects the kernel to reject {2, 2} and {1, 2} index tensors paired with a
// {4} out tensor. Skipped when running against the ATen kernel.
// NOTE(review): the two index shapes look broadcast-compatible, so the
// failure may actually stem from the out shape — confirm the intent.
TEST_F(OpIndexTensorOutTest, InvalidIndicesShapeDies2) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "";
  }
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({4, 4});

  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({2, 2}, {1, 1, 1, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {3, 0}))};

  Tensor result = tf_float.ones({4});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_, op_index_tensor_out(input, indices, result), "");
}
583
584 //
585 // Dynamic Shape Tests
586 //
587
// Test that the out tensor is resized when it is larger than the result.
// Starts with a {5, 5} DYNAMIC_BOUND out tensor and checks that it compares
// equal to the {1, 2} expected result after the call.
TEST_F(OpIndexTensorOutTest, UpperBoundOutTensor) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // clang-format off
  Tensor input = tf_double.make(
      {2, 3, 4},
      {
          // [0, :, :]
           1.,   2.,   3.,   4., // [0, 0, :]
           5.,   6.,   7.,   8., // [0, 1, :]
           9.,  10.,  11.,  12., // [0, 2, :]

          // [1, :, :]
          -1.,  -2.,  -3.,  -4., // [1, 0, :]
          -5.,  -6.,  -7.,  -8., // [1, 1, :]
          -9., -10., -11., -12., // [1, 2, :]
      });
  // clang-format on

  // Select the elements at [0, 2, 2] and [1, 1, 2].
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({1, 2}, {0, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {2, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {2, 2}))};

  // Deliberately larger than the {1, 2} result.
  Tensor result = tf_double.zeros(
      {5, 5}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  Tensor want = tf_double.make({1, 2}, {11., -7.});

  Tensor returned = op_index_tensor_out(input, indices, result);
  EXPECT_TENSOR_EQ(result, returned);
  EXPECT_TENSOR_EQ(returned, want);
}
630