1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::testing::SupportedFeatures;
22 using torch::executor::testing::TensorFactory;
23
24 class OpMaskedSelectOutTest : public OperatorTest {
25 protected:
26 Tensor&
op_masked_select_out(const Tensor & in,const Tensor & mask,Tensor & out)27 op_masked_select_out(const Tensor& in, const Tensor& mask, Tensor& out) {
28 return torch::executor::aten::masked_select_outf(context_, in, mask, out);
29 }
30 };
31
// Input and mask share the shape {2, 3}; the output is expected to hold the
// input values whose mask entry is true.
TEST_F(OpMaskedSelectOutTest, SmokeTest) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor selector =
      bool_factory.make({2, 3}, {true, false, false, true, false, true});
  Tensor result = int_factory.zeros({3});

  op_masked_select_out(input, selector, result);

  // The mask is true for the 1st, 4th and 6th listed input values.
  Tensor expected = int_factory.make({3}, {1, 4, 6});
  EXPECT_TENSOR_EQ(result, expected);
}
43
// The input has shape {3} while the mask has shape {2, 3}, so the input must
// be broadcast against the mask before selection.
TEST_F(OpMaskedSelectOutTest, BroadcastInput) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({3}, {1, 2, 3});
  Tensor selector =
      bool_factory.make({2, 3}, {true, false, false, true, false, true});
  Tensor result = int_factory.zeros({3});

  op_masked_select_out(input, selector, result);

  // Repeating {1, 2, 3} for both mask rows and keeping the true positions
  // yields {1, 1, 3}.
  Tensor expected = int_factory.make({3}, {1, 1, 3});
  EXPECT_TENSOR_EQ(result, expected);
}
55
// The mask has shape {3} while the input has shape {2, 3}, so the mask must
// be broadcast against the input before selection.
TEST_F(OpMaskedSelectOutTest, BroadcastMask) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor selector = bool_factory.make({3}, {false, true, false});
  Tensor result = int_factory.zeros({2});

  op_masked_select_out(input, selector, result);

  // Only the middle element of each group of three is selected.
  Tensor expected = int_factory.make({2}, {2, 5});
  EXPECT_TENSOR_EQ(result, expected);
}
68
// Input {2, 3, 4, 1} and mask {2, 1, 1, 5} both need broadcasting; with an
// all-true mask the output is expected to contain 2 * 3 * 4 * 5 = 120 ones.
TEST_F(OpMaskedSelectOutTest, BroadcastInputAndMask) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.ones({2, 3, 4, 1});
  Tensor selector = bool_factory.ones({2, 1, 1, 5});
  Tensor result = int_factory.zeros({120});

  op_masked_select_out(input, selector, result);

  Tensor expected = int_factory.ones({120});
  EXPECT_TENSOR_EQ(result, expected);
}
80
// The input has a zero-sized dimension ({2, 0}) and the mask is {2, 1}; the
// output is expected to be an empty 1-D tensor.
TEST_F(OpMaskedSelectOutTest, EmptyInput) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({2, 0}, {});
  Tensor selector = bool_factory.make({2, 1}, {true, true});
  Tensor result = int_factory.zeros({0});

  op_masked_select_out(input, selector, result);

  Tensor expected = int_factory.make({0}, {});
  EXPECT_TENSOR_EQ(result, expected);
}
92
// The mask has a zero-sized dimension ({2, 0}) and the input is {2, 1}; the
// output is expected to be an empty 1-D tensor.
TEST_F(OpMaskedSelectOutTest, EmptyMask) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({2, 1}, {100, 200});
  Tensor selector = bool_factory.make({2, 0}, {});
  Tensor result = int_factory.zeros({0});

  op_masked_select_out(input, selector, result);

  Tensor expected = int_factory.make({0}, {});
  EXPECT_TENSOR_EQ(result, expected);
}
104
// Both operands are empty (input {2, 0}, mask {0}); the output is expected
// to be an empty 1-D tensor.
TEST_F(OpMaskedSelectOutTest, EmptyInputAndMask) {
  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({2, 0}, {});
  Tensor selector = bool_factory.make({0}, {});
  Tensor result = int_factory.zeros({0});

  op_masked_select_out(input, selector, result);

  Tensor expected = int_factory.make({0}, {});
  EXPECT_TENSOR_EQ(result, expected);
}
116