// xref: /aosp_15_r20/external/executorch/kernels/test/op_masked_select_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using torch::executor::testing::SupportedFeatures;
22 using torch::executor::testing::TensorFactory;
23 
24 class OpMaskedSelectOutTest : public OperatorTest {
25  protected:
26   Tensor&
op_masked_select_out(const Tensor & in,const Tensor & mask,Tensor & out)27   op_masked_select_out(const Tensor& in, const Tensor& mask, Tensor& out) {
28     return torch::executor::aten::masked_select_outf(context_, in, mask, out);
29   }
30 };
31 
// Same-shape input and mask: the elements at flat (row-major) positions
// 0, 3 and 5 are selected.
TEST_F(OpMaskedSelectOutTest, SmokeTest) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  Tensor in = tf.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor mask = tfBool.make({2, 3}, {true, false, false, true, false, true});
  Tensor out = tf.zeros({3});

  // Keep the returned tensor: an out-variant kernel is expected to return
  // `out` itself, which the original test never checked.
  Tensor ret = op_masked_select_out(in, mask, out);

  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_EQ(out, tf.make({3}, {1, 4, 6}));
}

// `in` has shape {3} and is broadcast against the {2, 3} mask, so both mask
// rows index into the same values {1, 2, 3}: row 0 keeps column 0 and row 1
// keeps columns 0 and 2, giving {1, 1, 3}.
TEST_F(OpMaskedSelectOutTest, BroadcastInput) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor mask = tf_bool.make({2, 3}, {true, false, false, true, false, true});
  Tensor in = tf_int.make({3}, {1, 2, 3});
  Tensor out = tf_int.zeros({3});
  Tensor expected = tf_int.make({3}, {1, 1, 3});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, expected);
}

// The {3} mask is broadcast over both rows of the {2, 3} input, so only the
// middle column survives: 2 from row 0 and 5 from row 1.
TEST_F(OpMaskedSelectOutTest, BroadcastMask) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor mask = tf_bool.make({3}, {false, true, false});
  Tensor in = tf_int.make({2, 3}, {1, 2, 3, 4, 5, 6});
  Tensor out = tf_int.zeros({2});
  Tensor expected = tf_int.make({2}, {2, 5});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, expected);
}

// Both operands need broadcasting to reach a common shape.
TEST_F(OpMaskedSelectOutTest, BroadcastInputAndMask) {
  TensorFactory<ScalarType::Int> tf;
  TensorFactory<ScalarType::Bool> tfBool;

  // {2, 3, 4, 1} and {2, 1, 1, 5} broadcast to {2, 3, 4, 5}. With an all-true
  // mask, every one of the 2*3*4*5 = 120 broadcast elements is selected.
  Tensor in = tf.ones({2, 3, 4, 1});
  Tensor mask = tfBool.ones({2, 1, 1, 5});
  Tensor out = tf.zeros({120});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, tf.ones({120}));

  // All-ones data and an all-true mask only check the output length, not
  // which input element lands where. Repeat with distinct values:
  // `in2` {2, 1} broadcasts along columns and `mask2` {1, 3} along rows to
  // the common shape {2, 3}:
  //   in:   [[10, 10, 10],     mask: [[T, T, F],
  //          [20, 20, 20]]            [T, T, F]]
  // Row-major selection therefore yields {10, 10, 20, 20}. (Indexing `in2`
  // by flat broadcast index modulo its numel would give {10, 20, 20, 10}.)
  Tensor in2 = tf.make({2, 1}, {10, 20});
  Tensor mask2 = tfBool.make({1, 3}, {true, true, false});
  Tensor out2 = tf.zeros({4});

  op_masked_select_out(in2, mask2, out2);
  EXPECT_TENSOR_EQ(out2, tf.make({4}, {10, 10, 20, 20}));
}

// An empty {2, 0} input broadcast with a {2, 1} mask gives shape {2, 0}.
// There are no elements, so nothing is selected even though the mask is
// all true.
TEST_F(OpMaskedSelectOutTest, EmptyInput) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor mask = tf_bool.make({2, 1}, {true, true});
  Tensor in = tf_int.make({2, 0}, {});
  Tensor out = tf_int.zeros({0});
  Tensor expected = tf_int.make({0}, {});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, expected);
}

// An empty {2, 0} mask broadcast with a {2, 1} input gives shape {2, 0},
// so the selection is empty regardless of the input's values.
TEST_F(OpMaskedSelectOutTest, EmptyMask) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor mask = tf_bool.make({2, 0}, {});
  Tensor in = tf_int.make({2, 1}, {100, 200});
  Tensor out = tf_int.zeros({0});
  Tensor expected = tf_int.make({0}, {});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, expected);
}

// Both operands are empty: {2, 0} and {0} broadcast to {2, 0}, so the
// result is an empty 1-D tensor.
TEST_F(OpMaskedSelectOutTest, EmptyInputAndMask) {
  TensorFactory<ScalarType::Bool> tf_bool;
  TensorFactory<ScalarType::Int> tf_int;

  Tensor mask = tf_bool.make({0}, {});
  Tensor in = tf_int.make({2, 0}, {});
  Tensor out = tf_int.zeros({0});
  Tensor expected = tf_int.make({0}, {});

  op_masked_select_out(in, mask, out);
  EXPECT_TENSOR_EQ(out, expected);
}
