1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17 #include <sys/types.h>
18
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::optional;
22 using exec_aten::ScalarType;
23 using exec_aten::Tensor;
24 using torch::executor::testing::TensorFactory;
25
26 using OptTensorArrayRef = ArrayRef<optional<Tensor>>;
27
28 class OpIndexPutOutTest : public OperatorTest {
29 protected:
op_index_put_out(const Tensor & input,OptTensorArrayRef indices,const Tensor & values,const bool accumulate,Tensor & out)30 Tensor& op_index_put_out(
31 const Tensor& input,
32 OptTensorArrayRef indices,
33 const Tensor& values,
34 const bool accumulate,
35 Tensor& out) {
36 #ifdef USE_ATEN_LIB
37 c10::List<std::optional<at::Tensor>> indices_list(indices);
38 return torch::executor::aten::index_put_outf(
39 context_, input, indices_list, values, accumulate, out);
40 #else
41 return torch::executor::aten::index_put_outf(
42 context_, input, indices, values, accumulate, out);
43 #endif
44 }
45
46 template <
47 exec_aten::ScalarType INPUT_DTYPE,
48 exec_aten::ScalarType INDICES_DTYPE>
test_dtype()49 void test_dtype() {
50 TensorFactory<INPUT_DTYPE> tf;
51 TensorFactory<INDICES_DTYPE> tfl;
52 TensorFactory<ScalarType::Bool> tfb;
53
54 // clang-format off
55 Tensor x = tf.make(
56 {3, 2, 4},
57 {
58 // [0, :, :]
59 1, 1, 1, 1, // [0, 0, :]
60 0, 0, 0, 0, // [0, 1, :]
61
62 // [1, :, :]
63 1, 1, 1, 1, // [1, 0, :]
64 0, 0, 0, 0, // [1, 1, :]
65
66 // [2, :, :]
67 1, 1, 1, 1, // [2, 0, :]
68 0, 0, 0, 0, // [2, 1, :]
69 });
70 // clang-format on
71
72 // First, index_put to make everything equal to 1
73
74 // indices [0, 1, :], [1, 1, :], [2, 1, :]
75 optional<Tensor> indices[] = {
76 optional<Tensor>(tfl.make({1, 3}, {0, 1, 2})),
77 optional<Tensor>(tfl.make({1, 3}, {1, 1, 1})),
78 };
79 // bool representation of the same index list
80 optional<Tensor> indices_bool[] = {
81 optional<Tensor>(tfb.make({3}, {true, true, true})),
82 optional<Tensor>(tfb.make({2}, {false, true})),
83 };
84
85 Tensor values = tf.ones({3, 4});
86
87 std::vector<int32_t> out_size{3, 2, 4};
88
89 Tensor out = tf.zeros(out_size);
90 Tensor ret =
91 op_index_put_out(x, indices, values, /*accumulate=*/false, out);
92
93 EXPECT_TENSOR_EQ(ret, out);
94 EXPECT_TENSOR_EQ(ret, tf.ones(out_size));
95
96 // Repeat the test with bool indices
97 Tensor out_with_bool = tf.zeros(out_size);
98 Tensor ret_with_bool = op_index_put_out(
99 x, indices_bool, values, /*accumulate=*/false, out_with_bool);
100
101 EXPECT_TENSOR_EQ(ret_with_bool, out_with_bool);
102 EXPECT_TENSOR_EQ(ret_with_bool, tf.ones(out_size));
103
104 // Then, index_put to make everything equal to 0
105
106 // indices [0, 1, :], [1, 0, :], [2, 0, :]
107 optional<Tensor> indices_alt[] = {
108 optional<Tensor>(tfl.make({1, 3}, {0, 1, 2})),
109 optional<Tensor>(tfl.make({1, 3}, {0, 0, 0})),
110 };
111 // bool representation of the same index list
112 optional<Tensor> indices_alt_bool[] = {
113 optional<Tensor>(tfb.make({3}, {true, true, true})),
114 optional<Tensor>(tfb.make({2}, {true, false})),
115 };
116
117 Tensor values_alt = tf.zeros({3, 4});
118
119 Tensor out_alt = tf.ones(out_size);
120 Tensor ret_alt = op_index_put_out(
121 x, indices_alt, values_alt, /*accumulate=*/false, out_alt);
122
123 EXPECT_TENSOR_EQ(ret_alt, out_alt);
124 EXPECT_TENSOR_EQ(ret_alt, tf.zeros(out_size));
125
126 // Repeat the test with bool indices
127 Tensor out_alt_with_bool = tf.ones(out_size);
128 Tensor ret_alt_with_bool = op_index_put_out(
129 x,
130 indices_alt_bool,
131 values_alt,
132 /*accumulate=*/false,
133 out_alt_with_bool);
134
135 EXPECT_TENSOR_EQ(ret_alt_with_bool, out_alt_with_bool);
136 EXPECT_TENSOR_EQ(ret_alt_with_bool, tf.zeros(out_size));
137 }
138
139 /* %python
140 import torch
141 torch.manual_seed(0)
142 input = torch.rand(2, 3, 4)
143 indices = [torch.tensor([1]), torch.tensor([0]), torch.tensor([1, 2])]
144 values = torch.rand(2)
145 accumulate = False
146 expected = input.index_put(indices, values, accumulate=accumulate)
147
148 index_put_template = f"""
149 {declare_tensor_factory("ScalarType::Float", "tf")}
150 {declare_tensor_factory("ScalarType::Long", "tf_indices")}
151
152 {declare_tensor_make_t("input", "tf")}
153 {declare_optional_tensor_list_make_t("indices", "tf_indices")}
154 {declare_tensor_make_t("values", "tf")}
155 {declare_tensor_make_t("expected", "tf")}
156 {declare_tensor_zeros("out_shape, dynamism", "tf", "out")}
157
158 op_index_put_out(input, indices, values, $accumulate$, out);
159 EXPECT_TENSOR_EQ(out, expected);"""
160 */
test_dynamic_shape(const std::vector<int32_t> & out_shape,enum torch::executor::TensorShapeDynamism dynamism)161 void test_dynamic_shape(
162 const std::vector<int32_t>& out_shape,
163 enum torch::executor::TensorShapeDynamism dynamism) {
164 /* %python
165 %rewrite(index_put_template) */
166
167 TensorFactory<ScalarType::Float> tf;
168 TensorFactory<ScalarType::Long> tf_indices;
169
170 Tensor input = tf.make(
171 {2, 3, 4},
172 {0.49625658988952637, 0.7682217955589294, 0.08847743272781372,
173 0.13203048706054688, 0.30742281675338745, 0.6340786814689636,
174 0.4900934100151062, 0.8964447379112244, 0.455627977848053,
175 0.6323062777519226, 0.3488934636116028, 0.40171730518341064,
176 0.022325754165649414, 0.16885894536972046, 0.2938884496688843,
177 0.518521785736084, 0.6976675987243652, 0.800011396408081,
178 0.16102945804595947, 0.28226858377456665, 0.6816085577011108,
179 0.9151939749717712, 0.39709991216659546, 0.8741558790206909});
180 optional<Tensor> indices[] = {
181 optional<Tensor>(tf_indices.make({1}, {1})),
182 optional<Tensor>(tf_indices.make({1}, {0})),
183 optional<Tensor>(tf_indices.make({2}, {1, 2}))};
184 Tensor values = tf.make({2}, {0.41940832138061523, 0.5529070496559143});
185 Tensor expected = tf.make(
186 {2, 3, 4},
187 {0.49625658988952637, 0.7682217955589294, 0.08847743272781372,
188 0.13203048706054688, 0.30742281675338745, 0.6340786814689636,
189 0.4900934100151062, 0.8964447379112244, 0.455627977848053,
190 0.6323062777519226, 0.3488934636116028, 0.40171730518341064,
191 0.022325754165649414, 0.41940832138061523, 0.5529070496559143,
192 0.518521785736084, 0.6976675987243652, 0.800011396408081,
193 0.16102945804595947, 0.28226858377456665, 0.6816085577011108,
194 0.9151939749717712, 0.39709991216659546, 0.8741558790206909});
195 Tensor out = tf.zeros(out_shape, dynamism);
196
197 op_index_put_out(input, indices, values, false, out);
198 EXPECT_TENSOR_EQ(out, expected);
199 }
200
201 // Run the test by putting values into the selected elements
run_test_cases(const Tensor & x,OptTensorArrayRef indices,const Tensor & values,const Tensor & expected,const Tensor & expected_accum)202 void run_test_cases(
203 const Tensor& x,
204 OptTensorArrayRef indices,
205 const Tensor& values,
206 const Tensor& expected,
207 const Tensor& expected_accum) {
208 // Generated out tensor sharing same size and dtype with expected tensor
209 TensorFactory<ScalarType::Double> tf;
210
211 const std::vector<int32_t> out_size(
212 expected.sizes().begin(), expected.sizes().end());
213 Tensor out = tf.ones(out_size);
214
215 Tensor ret =
216 op_index_put_out(x, indices, values, /*accumulate=*/false, out);
217 EXPECT_TENSOR_EQ(out, ret);
218 EXPECT_TENSOR_EQ(ret, expected);
219
220 Tensor out_accum = tf.ones(out_size);
221 Tensor ret_accum =
222 op_index_put_out(x, indices, values, /*accumulate=*/true, out_accum);
223 EXPECT_TENSOR_EQ(out_accum, ret_accum);
224 EXPECT_TENSOR_EQ(ret_accum, expected_accum);
225 }
226 };
227
228 //
229 // Correctness Tests
230 //
231
// A single boolean mask shaped like the input selects four scattered
// elements; each one receives the matching entry of a four-element `values`.
TEST_F(OpIndexPutOutTest, IndexPutMask) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Bool> tf_bool;

  // clang-format off
  Tensor input = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
        1.,   2.,   3.,   4., // [0, 0, :]
        5.,   6.,   7.,   8., // [0, 1, :]
        9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
       -1.,  -2.,  -3.,  -4., // [1, 0, :]
       -5.,  -6.,  -7.,  -8., // [1, 1, :]
       -9., -10., -11., -12., // [1, 2, :]
    });

  // Selected positions: [0, 0, 0], [0, 1, 2], [1, 0, 1], [1, 2, 2].
  optional<Tensor> mask_indices[] = {optional<Tensor>(tf_bool.make(
    {2, 3, 4},
    {
      // [0, :, :]
      true,  false, false, false, // [0, 0, :]
      false, false, true,  false, // [0, 1, :]
      false, false, false, false, // [0, 2, :]

      // [1, :, :]
      false, true,  false, false, // [1, 0, :]
      false, false, false, false, // [1, 1, :]
      false, false, true,  false, // [1, 2, :]
    }))};

  // One value per selected position, consumed in row-major order.
  Tensor values = tf_double.make({4}, {10., 20., 30., 40.});

  Tensor expected = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       10.,   2.,   3.,   4., // [0, 0, :]
        5.,   6.,  20.,   8., // [0, 1, :]
        9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
       -1.,  30.,  -3.,  -4., // [1, 0, :]
       -5.,  -6.,  -7.,  -8., // [1, 1, :]
       -9., -10.,  40., -12., // [1, 2, :]
    });

  // Accumulating adds each value to the original element instead.
  Tensor expected_accum = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       11.,   2.,   3.,   4., // [0, 0, :]
        5.,   6.,  27.,   8., // [0, 1, :]
        9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
       -1.,  28.,  -3.,  -4., // [1, 0, :]
       -5.,  -6.,  -7.,  -8., // [1, 1, :]
       -9., -10.,  29., -12., // [1, 2, :]
    });
  // clang-format on

  run_test_cases(input, mask_indices, values, expected, expected_accum);
}
308
// Same mask as IndexPutMask, but with a single-element `values` tensor that
// must be broadcast to every selected position.
TEST_F(OpIndexPutOutTest, IndexPutMaskBroadcast) {
  TensorFactory<ScalarType::Double> tf;
  TensorFactory<ScalarType::Bool> tfb;
  // clang-format off
  Tensor x = tf.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // The mask selects [0, 0, 0], [0, 1, 2], [1, 0, 1] and [1, 2, 2]. Every
  // one of them receives the single broadcast value 10.

  // clang-format off
  Tensor indices = tfb.make(
    {2, 3, 4},
    {
      // [0, :, :]
      true,  false, false, false, // [0, 0, :]
      false, false, true,  false, // [0, 1, :]
      false, false, false, false, // [0, 2, :]

      // [1, :, :]
      false, true,  false, false, // [1, 0, :]
      false, false, false, false, // [1, 1, :]
      false, false, true,  false, // [1, 2, :]
    });
  // clang-format on

  // clang-format off
  Tensor values = tf.make(
    {1},
    {10.}
  );
  // clang-format on

  // clang-format off
  Tensor expected = tf.make(
    {2, 3, 4},
    {
      // [0, :, :]
      10.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,  10.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  10.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10.,  10., -12., // [1, 2, :]
    });
  // clang-format on

  // With accumulate=true, 10 is added to each selected element.
  // clang-format off
  Tensor expected_accum = tf.make(
    {2, 3, 4},
    {
      // [0, :, :]
      11.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,  17.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,   8.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10.,  -1., -12., // [1, 2, :]
    });
  // clang-format on

  run_test_cases(x, {indices}, values, expected, expected_accum);
}
389
// Indexes every dimension of the input and addresses the same two elements
// with several index encodings: int64, int32, negative, boolean, and mixed.
TEST_F(OpIndexPutOutTest, PutFrontDimAllIndexes) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Bool> tf_bool;

  // clang-format off
  Tensor input = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Each variant below targets elements [1, 0, 1] and [1, 0, 2].

  // Plain int64 positions.
  optional<Tensor> long_indices[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({1}, {0})),
      optional<Tensor>(tf_long.make({2}, {1, 2}))};

  // Plain int32 positions.
  optional<Tensor> int_indices[] = {
      optional<Tensor>(tf_int.make({1}, {1})),
      optional<Tensor>(tf_int.make({1}, {0})),
      optional<Tensor>(tf_int.make({2}, {1, 2}))};

  // Negative positions, counted back from the end of each dimension.
  optional<Tensor> negative_indices[] = {
      optional<Tensor>(tf_long.make({1}, {-1})),
      optional<Tensor>(tf_long.make({1}, {0})),
      optional<Tensor>(tf_long.make({2}, {-3, -2}))};

  // Boolean masks for the first two dimensions.
  optional<Tensor> bool_indices[] = {
      optional<Tensor>(tf_bool.make({2}, {false, true})),
      optional<Tensor>(tf_bool.make({3}, {true, false, false})),
      optional<Tensor>(tf_long.make({2}, {-3, -2}))};

  // A mask for the first dimension alongside integral positions.
  optional<Tensor> mixed_indices[] = {
      optional<Tensor>(tf_bool.make({2}, {false, true})),
      optional<Tensor>(tf_long.make({1}, {0})),
      optional<Tensor>(tf_long.make({2}, {-3, -2}))};

  Tensor values = tf_double.make({2}, {10., 20.});

  // clang-format off
  Tensor expected = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  10.,  20.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });

  Tensor expected_accum = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,   8.,  17.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Every encoding must produce the same results.
  const OptTensorArrayRef index_variants[] = {
      long_indices, int_indices, negative_indices, bool_indices, mixed_indices};
  for (const OptTensorArrayRef& variant : index_variants) {
    run_test_cases(input, variant, values, expected, expected_accum);
  }
}
485
// Both index columns address the same element, so it is written twice.
TEST_F(OpIndexPutOutTest, PutTwoValuesAtSameIndex) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // clang-format off
  Tensor input = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Two copies of position [0, 1, 2].
  optional<Tensor> duplicate_indices[] = {
      optional<Tensor>(tf_long.make({1, 2}, {0, 0})),
      optional<Tensor>(tf_long.make({1, 2}, {1, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {2, 2}))};

  Tensor values = tf_double.make({1}, {10.});

  // Overwriting leaves 10 at [0, 1, 2]; accumulating adds 10 twice to the
  // original 7, giving 27.
  // clang-format off
  Tensor expected = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,  10.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });

  Tensor expected_accum = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,  27.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  run_test_cases(input, duplicate_indices, values, expected, expected_accum);
}
552
// Only the first two of three dimensions are indexed, so each selected
// position is an entire row along the last dimension.
TEST_F(OpIndexPutOutTest, IndicesFewerThanInputDimSupported) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Long> tf_long;
  TensorFactory<ScalarType::Bool> tf_bool;

  // clang-format off
  Tensor input = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Both variants select rows [1, 0, :] and [1, 1, :].

  // int64 positions for both indexed dimensions.
  optional<Tensor> long_indices[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({2}, {0, 1}))};

  // A negative int32 position for dim 0 plus a boolean mask for dim 1.
  optional<Tensor> mixed_indices[] = {
      optional<Tensor>(tf_int.make({1}, {-1})),
      optional<Tensor>(tf_bool.make({3}, {true, true, false}))};

  // clang-format off
  // One row of values per selected row.
  Tensor values = tf_double.make(
    {2, 4},
    {
       10.,  20.,  30.,  40.,
      -10., -20., -30., -40.,
    });

  Tensor expected = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
        1.,   2.,   3.,   4., // [0, 0, :]
        5.,   6.,   7.,   8., // [0, 1, :]
        9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
       10.,  20.,  30.,  40., // [1, 0, :]
      -10., -20., -30., -40., // [1, 1, :]
       -9., -10., -11., -12., // [1, 2, :]
    });

  Tensor expected_accum = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
        1.,   2.,   3.,   4., // [0, 0, :]
        5.,   6.,   7.,   8., // [0, 1, :]
        9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
        9.,  18.,  27.,  36., // [1, 0, :]
      -15., -26., -37., -48., // [1, 1, :]
       -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  const OptTensorArrayRef index_variants[] = {long_indices, mixed_indices};
  for (const OptTensorArrayRef& variant : index_variants) {
    run_test_cases(input, variant, values, expected, expected_accum);
  }
}
631
// Like IndicesFewerThanInputDimSupported, but a single value is broadcast
// across both selected rows.
TEST_F(OpIndexPutOutTest, IndicesFewerThanInputDimSupportedSameValue) {
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  // clang-format off
  Tensor input = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      -1.,  -2.,  -3.,  -4., // [1, 0, :]
      -5.,  -6.,  -7.,  -8., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  // Rows [1, 0, :] and [1, 1, :].
  optional<Tensor> row_indices[] = {
      optional<Tensor>(tf_long.make({1}, {1})),
      optional<Tensor>(tf_long.make({2}, {0, 1}))};

  Tensor values = tf_double.make({1}, {10.});

  // clang-format off
  Tensor expected = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
      10.,  10.,  10.,  10., // [1, 0, :]
      10.,  10.,  10.,  10., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });

  Tensor expected_accum = tf_double.make(
    {2, 3, 4},
    {
      // [0, :, :]
       1.,   2.,   3.,   4., // [0, 0, :]
       5.,   6.,   7.,   8., // [0, 1, :]
       9.,  10.,  11.,  12., // [0, 2, :]

      // [1, :, :]
       9.,   8.,   7.,   6., // [1, 0, :]
       5.,   4.,   3.,   2., // [1, 1, :]
      -9., -10., -11., -12., // [1, 2, :]
    });
  // clang-format on

  run_test_cases(input, row_indices, values, expected, expected_accum);
}
698
699 //
700 // Test that all dtypes are supported
701 //
702
703 /**
704 * Generic test for integral index lists
705 */
TEST_F(OpIndexPutOutTest,AllDtypesSupportedForInput)706 TEST_F(OpIndexPutOutTest, AllDtypesSupportedForInput) {
707 #define TEST_ENTRY(ctype, dtype) \
708 test_dtype<ScalarType::dtype, ScalarType::Long>();
709
710 ET_FORALL_REALHBBF16_TYPES(TEST_ENTRY);
711
712 #undef TEST_ENTRY
713 }
714
// Exercises both integral index dtypes against a Float input.
TEST_F(OpIndexPutOutTest, AllDtypesSupportedForIndicesList) {
  // int64 index tensors.
  test_dtype<ScalarType::Float, ScalarType::Long>();

  // int32 index tensors.
  test_dtype<ScalarType::Float, ScalarType::Int>();
}
719
720 //
721 // Death Tests
722 //
723
// Position 5 does not exist in a dimension of size 1.
TEST_F(OpIndexPutOutTest, IndexOutOfBoundDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> indices[] = {optional<Tensor>(tf_long.make({1}, {5}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
745
// Position -5 reaches past the start of a dimension of size 1.
TEST_F(OpIndexPutOutTest, NegativeIndexOutOfBoundDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> indices[] = {optional<Tensor>(tf_long.make({1}, {-5}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
767
// A 3-element mask is longer than the size-1 dimension it indexes.
TEST_F(OpIndexPutOutTest, TooManyBooleanIndexCountDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({1, 1, 1});
  Tensor out = tf_float.zeros({1, 1, 1});
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_bool.make({3}, {true, true, false}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
789
// A 1-element mask is shorter than the size-4 dimension it indexes.
TEST_F(OpIndexPutOutTest, TooFewBooleanIndexCountDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({4});
  Tensor out = tf_float.zeros({4});
  optional<Tensor> indices[] = {optional<Tensor>(tf_bool.make({1}, {true}))};
  Tensor values = tf_float.make({1}, {10.});

  // ATen kernel will throw exception instead of death
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
812
// A {3, 3} mask does not match the shape of the {4, 4} input.
TEST_F(OpIndexPutOutTest, MismatchedIndexMaskDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_float.ones({4, 4});
  Tensor out = tf_float.zeros({4, 4});
  optional<Tensor> indices[] = {optional<Tensor>(tf_bool.ones({3, 3}))};
  Tensor values = tf_float.make({1}, {10.});

  // ATen kernel will throw exception instead of death
  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
835
// `out` has the input's shape but a different dtype (Double vs. Float).
TEST_F(OpIndexPutOutTest, MismatchedOutputDtypesDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({1, 2, 2});
  Tensor out = tf_double.ones({1, 2, 2});
  optional<Tensor> indices[] = {optional<Tensor>(tf_long.make({1}, {0}))};
  Tensor values = tf_float.make({1}, {10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
860
// `values` has a different dtype (Double) from the Float input and output.
TEST_F(OpIndexPutOutTest, MismatchedValuesDtypesDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Double> tf_double;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor x = tf_float.zeros({1, 2, 2});

  // `out` matches the input in both size and dtype; only `values` below has
  // a mismatched dtype.
  Tensor out = tf_float.ones({1, 2, 2});
  Tensor index = tf_long.make({1}, {0});

  // Double values for a Float input: this is the mismatch under test.
  // clang-format off
  Tensor values = tf_double.make(
    {1},
    {10.}
  );
  // clang-format on

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(
          x, /*indices=*/{index}, values, /*accumulate=*/false, out),
      "");
}
885
// Selecting x[1] of a {2, 4, 7, 5} input leaves a {4, 7, 5} slice, which
// {1, 2}-shaped values cannot fill.
TEST_F(OpIndexPutOutTest, ValuesSizeMismatchDimDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({2, 4, 7, 5});
  optional<Tensor> indices[] = {optional<Tensor>(tf_long.make({1}, {1}))};
  Tensor out = tf_float.ones({2, 4, 7, 5});
  Tensor values = tf_float.make({1, 2}, {10., 10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
908
// Index tensors must be integral or boolean; Float indices are rejected.
TEST_F(OpIndexPutOutTest, InvalidIndicesDtypeDies) {
  // One Float factory serves the data tensors and the (invalid) indices.
  TensorFactory<ScalarType::Float> tf;

  Tensor input = tf.zeros({2, 4, 7, 5});
  optional<Tensor> float_indices[] = {
      optional<Tensor>(tf.make({3}, {1, 1, 1})),
      optional<Tensor>(tf.make({2}, {1, 2}))};

  Tensor out = tf.ones({2, 4, 7, 5});
  Tensor values = tf.make({1}, {10});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, float_indices, values, /*accumulate=*/false, out),
      "");
}
934
// Index tensors of sizes {3} and {2} cannot be broadcast together.
TEST_F(OpIndexPutOutTest, InvalidIndicesShapesDies) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({2, 4, 7, 5});
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({3}, {1, 1, 1})),
      optional<Tensor>(tf_long.make({2}, {1, 2}))};

  Tensor out = tf_float.ones({2, 4, 7, 5});
  Tensor values = tf_float.make({1, 2}, {10., 10.});

  ET_EXPECT_KERNEL_FAILURE_WITH_MSG(
      context_,
      op_index_put_out(input, indices, values, /*accumulate=*/false, out),
      "");
}
960
// Multi-dimensional index tensors: {2, 2} row indices (all 1) combined with
// {1, 2} column indices {3, 0} address positions [1, 3] and [1, 0].
TEST_F(OpIndexPutOutTest, NonLinearIndices) {
  TensorFactory<ScalarType::Float> tf_float;
  TensorFactory<ScalarType::Long> tf_long;

  Tensor input = tf_float.zeros({4, 4});
  optional<Tensor> indices[] = {
      optional<Tensor>(tf_long.make({2, 2}, {1, 1, 1, 1})),
      optional<Tensor>(tf_long.make({1, 2}, {3, 0}))};

  Tensor out = tf_float.ones({4, 4});
  Tensor values = tf_float.make({1}, {10.});

  // clang-format off
  Tensor expected = tf_float.make(
    {4, 4},
    {
       0, 0, 0,  0,
      10, 0, 0, 10,
       0, 0, 0,  0,
       0, 0, 0,  0,
    });
  // clang-format on

  Tensor result =
      op_index_put_out(input, indices, values, /*accumulate=*/false, out);

  EXPECT_TENSOR_EQ(result, out);
  EXPECT_TENSOR_EQ(result, expected);
}
989
990 //
991 // Dynamic Shape Tests
992 //
993
// The bounded `out` tensor is allocated at exactly the {2, 3, 4} result shape.
TEST_F(OpIndexPutOutTest, DynamicShapeUpperBoundSameAsExpected) {
  const std::vector<int32_t> upper_bound = {2, 3, 4};
  test_dynamic_shape(
      upper_bound, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
998
// The {10, 10, 10} bounded `out` tensor must be resized down to the
// {2, 3, 4} result; skipped where output resizing is unsupported.
TEST_F(OpIndexPutOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  const bool supports_output_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!supports_output_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }

  const std::vector<int32_t> upper_bound = {10, 10, 10};
  test_dynamic_shape(
      upper_bound, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
}
1006
// An unbounded {1, 1, 1} `out` tensor must grow to the {2, 3, 4} result;
// skipped where output resizing is unsupported.
TEST_F(OpIndexPutOutTest, DynamicShapeUnbound) {
  const bool supports_output_resize =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!supports_output_resize) {
    GTEST_SKIP() << "Dynamic shape not supported";
  }

  const std::vector<int32_t> initial_shape = {1, 1, 1};
  test_dynamic_shape(
      initial_shape, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
}
1014