1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15
16 #include <gtest/gtest.h>
17
18 using namespace ::testing;
19 using exec_aten::Scalar;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using torch::executor::testing::TensorFactory;
23
24 class OpEqScalarOutTest : public OperatorTest {
25 protected:
op_eq_scalar_out(const Tensor & self,Scalar & other,Tensor & out)26 Tensor& op_eq_scalar_out(const Tensor& self, Scalar& other, Tensor& out) {
27 return torch::executor::aten::eq_outf(context_, self, other, out);
28 }
29
30 // Common testing for eq operator
31 template <ScalarType DTYPE>
test_eq_scalar_out()32 void test_eq_scalar_out() {
33 TensorFactory<DTYPE> tf;
34 TensorFactory<ScalarType::Bool> tf_out;
35
36 const std::vector<int32_t> sizes = {2, 2};
37 // Destination for the eq
38 Tensor out = tf_out.ones(sizes);
39 Scalar other = 3;
40
41 // Valid input should give the expected output
42 op_eq_scalar_out(tf.make(sizes, /*data=*/{2, 3, 3, 3}), other, out);
43 EXPECT_TENSOR_EQ(
44 out, tf_out.make(sizes, /*data=*/{false, true, true, true}));
45 }
46
47 // Handle all output dtypes.
48 template <ScalarType OUTPUT_DTYPE>
test_eq_all_output_dtypes()49 void test_eq_all_output_dtypes() {
50 TensorFactory<ScalarType::Float> tf_float;
51 TensorFactory<OUTPUT_DTYPE> tf_out;
52
53 const std::vector<int32_t> sizes = {2, 5};
54
55 Tensor in = tf_float.ones(sizes);
56 Tensor out = tf_out.zeros(sizes);
57 Scalar other = 1;
58
59 op_eq_scalar_out(in, other, out);
60 EXPECT_TENSOR_EQ(out, tf_out.ones(sizes));
61 }
62 };
63
// Runs the shared test_eq_scalar_out() check once per input dtype enumerated
// by ET_FORALL_REAL_TYPES; the output dtype is always Bool.
TEST_F(OpEqScalarOutTest, AllRealInputBoolOutputSupport) {
#define TEST_ENTRY(cpp_type, scalar_type) \
  test_eq_scalar_out<ScalarType::scalar_type>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
69
// Bool input compared against the integer Scalar 1: only the `true` elements
// should compare equal.
TEST_F(OpEqScalarOutTest, BoolInputDtype) {
  TensorFactory<ScalarType::Bool> tf_bool;

  const std::vector<int32_t> sizes = {2, 2};
  Tensor a = tf_bool.make(sizes, /*data=*/{false, true, false, true});
  // Seed `out` with the negation of the expected result. A zero fill would
  // already match the expected `false` elements and hide a kernel that never
  // wrote them; this way every element must be overwritten to pass.
  Tensor out = tf_bool.make(sizes, /*data=*/{true, false, true, false});
  Scalar other = 1;

  op_eq_scalar_out(a, other, out);
  EXPECT_TENSOR_EQ(
      out, tf_bool.make(sizes, /*data=*/{false, true, false, true}));
}
82
83 // Mismatched shape tests.
// Outside ATen mode, pairing a {4} input with a {2, 2} output is expected to
// make the kernel report a failure through `context_`.
TEST_F(OpEqScalarOutTest, MismatchedShapesDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Bool> bool_factory;
  TensorFactory<ScalarType::Int> int_factory;

  Scalar three = 3;
  Tensor input = int_factory.ones(/*sizes=*/{4});
  Tensor result = bool_factory.ones(/*sizes=*/{2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_eq_scalar_out(input, three, result));
}
97
// Runs test_eq_all_output_dtypes() for every dtype in ET_FORALL_REAL_TYPES as
// the output dtype. Skipped under ATen.
TEST_F(OpEqScalarOutTest, AllRealOutputDTypes) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle non-bool output dtype";
  }
#define TEST_ENTRY(cpp_type, scalar_type) \
  test_eq_all_output_dtypes<ScalarType::scalar_type>();
  ET_FORALL_REAL_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}
106
107 /* %python
108 import torch
109 torch.manual_seed(0)
110 x = torch.randint(3, (3, 2))
111 res = torch.eq(x, 2)
112 op = "op_eq_scalar_out"
113 opt_setup_params = """
114 Scalar other = 2;
115 """
116 opt_extra_params = "other,"
117 dtype = "ScalarType::Int"
118 out_dtype = "ScalarType::Bool"
119 check = "EXPECT_TENSOR_EQ" */
120
// DYNAMIC_BOUND output whose upper bound ({3, 2}) equals the result shape.
TEST_F(OpEqScalarOutTest, DynamicShapeUpperBoundSameAsExpected) {
  /* %python
  out_args = "{3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op_out_dtype) */

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  // Only the elements equal to 2 should compare true.
  Tensor input = int_factory.make({3, 2}, {2, 0, 2, 0, 1, 0});
  Scalar two = 2;
  Tensor result = bool_factory.zeros(
      {3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  op_eq_scalar_out(input, two, result);

  EXPECT_TENSOR_EQ(
      result,
      bool_factory.make({3, 2}, {true, false, true, false, false, false}));
}
140
// DYNAMIC_BOUND output allocated at {10, 10}; after the call it is compared
// against a {3, 2} expected tensor.
TEST_F(OpEqScalarOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  /* %python
  out_args = "{10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op_out_dtype) */

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  // Only the elements equal to 2 should compare true.
  Tensor input = int_factory.make({3, 2}, {2, 0, 2, 0, 1, 0});
  Scalar two = 2;
  Tensor result = bool_factory.zeros(
      {10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  op_eq_scalar_out(input, two, result);

  EXPECT_TENSOR_EQ(
      result,
      bool_factory.make({3, 2}, {true, false, true, false, false, false}));
}
160
// DYNAMIC_UNBOUND output allocated at {1, 1}; after the call it is compared
// against a {3, 2} expected tensor. Skipped when output resizing is not
// supported.
TEST_F(OpEqScalarOutTest, DynamicShapeUnbound) {
  const bool can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape unbound not supported";
  }
  /* %python
  out_args = "{1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND"
  %rewrite(unary_op_out_dtype) */

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  // Only the elements equal to 2 should compare true.
  Tensor input = int_factory.make({3, 2}, {2, 0, 2, 0, 1, 0});
  Scalar two = 2;
  Tensor result = bool_factory.zeros(
      {1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);

  op_eq_scalar_out(input, two, result);

  EXPECT_TENSOR_EQ(
      result,
      bool_factory.make({3, 2}, {true, false, true, false, false, false}));
}
183