xref: /aosp_15_r20/external/executorch/kernels/test/op_eq_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::Scalar;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using torch::executor::testing::TensorFactory;
23 
24 class OpEqScalarOutTest : public OperatorTest {
25  protected:
op_eq_scalar_out(const Tensor & self,Scalar & other,Tensor & out)26   Tensor& op_eq_scalar_out(const Tensor& self, Scalar& other, Tensor& out) {
27     return torch::executor::aten::eq_outf(context_, self, other, out);
28   }
29 
30   // Common testing for eq operator
31   template <ScalarType DTYPE>
test_eq_scalar_out()32   void test_eq_scalar_out() {
33     TensorFactory<DTYPE> tf;
34     TensorFactory<ScalarType::Bool> tf_out;
35 
36     const std::vector<int32_t> sizes = {2, 2};
37     // Destination for the eq
38     Tensor out = tf_out.ones(sizes);
39     Scalar other = 3;
40 
41     // Valid input should give the expected output
42     op_eq_scalar_out(tf.make(sizes, /*data=*/{2, 3, 3, 3}), other, out);
43     EXPECT_TENSOR_EQ(
44         out, tf_out.make(sizes, /*data=*/{false, true, true, true}));
45   }
46 
47   // Handle all output dtypes.
48   template <ScalarType OUTPUT_DTYPE>
test_eq_all_output_dtypes()49   void test_eq_all_output_dtypes() {
50     TensorFactory<ScalarType::Float> tf_float;
51     TensorFactory<OUTPUT_DTYPE> tf_out;
52 
53     const std::vector<int32_t> sizes = {2, 5};
54 
55     Tensor in = tf_float.ones(sizes);
56     Tensor out = tf_out.zeros(sizes);
57     Scalar other = 1;
58 
59     op_eq_scalar_out(in, other, out);
60     EXPECT_TENSOR_EQ(out, tf_out.ones(sizes));
61   }
62 };
63 
// Runs the common scalar-eq check once for every dtype enumerated by
// ET_FORALL_REAL_TYPES, always with a Bool output tensor.
TEST_F(OpEqScalarOutTest, AllRealInputBoolOutputSupport) {
#define RUN_EQ_SCALAR_FOR_DTYPE(ctype, dtype) \
  test_eq_scalar_out<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_EQ_SCALAR_FOR_DTYPE);
#undef RUN_EQ_SCALAR_FOR_DTYPE
}
69 
// Bool input compared against the scalar 1: the expected output has the same
// true/false pattern as the input.
TEST_F(OpEqScalarOutTest, BoolInputDtype) {
  TensorFactory<ScalarType::Bool> bool_factory;
  const std::vector<int32_t> shape = {2, 2};

  Tensor input =
      bool_factory.make(shape, /*data=*/{false, true, false, true});
  Tensor result = bool_factory.zeros(shape);
  Scalar rhs = 1;

  op_eq_scalar_out(input, rhs, result);

  Tensor expected =
      bool_factory.make(shape, /*data=*/{false, true, false, true});
  EXPECT_TENSOR_EQ(result, expected);
}
82 
83 // Mismatched shape tests.
// A {4} input with a {2, 2} output must be reported as a kernel failure.
// Skipped when the ATen kernel is under test.
TEST_F(OpEqScalarOutTest, MismatchedShapesDies) {
  const auto running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  // Input and output deliberately disagree on shape.
  Tensor input = int_factory.ones(/*sizes=*/{4});
  Tensor result = bool_factory.ones(/*sizes=*/{2, 2});
  Scalar rhs = 3;

  ET_EXPECT_KERNEL_FAILURE(context_, op_eq_scalar_out(input, rhs, result));
}
97 
// Runs the output-dtype check once for every dtype enumerated by
// ET_FORALL_REAL_TYPES. Skipped when the ATen kernel is under test.
TEST_F(OpEqScalarOutTest, AllRealOutputDTypes) {
  const auto running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle non-bool output dtype";
  }
#define RUN_EQ_FOR_OUTPUT_DTYPE(ctype, dtype) \
  test_eq_all_output_dtypes<ScalarType::dtype>();
  ET_FORALL_REAL_TYPES(RUN_EQ_FOR_OUTPUT_DTYPE);
#undef RUN_EQ_FOR_OUTPUT_DTYPE
}
106 
107 /* %python
108 import torch
109 torch.manual_seed(0)
110 x = torch.randint(3, (3, 2))
111 res = torch.eq(x, 2)
112 op = "op_eq_scalar_out"
113 opt_setup_params = """
114   Scalar other = 2;
115 """
116 opt_extra_params = "other,"
117 dtype = "ScalarType::Int"
118 out_dtype = "ScalarType::Bool"
119 check = "EXPECT_TENSOR_EQ" */
120 
// Output is DYNAMIC_BOUND and created with exactly the expected {3, 2} shape.
TEST_F(OpEqScalarOutTest, DynamicShapeUpperBoundSameAsExpected) {
  /* %python
  out_args = "{3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op_out_dtype) */

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({3, 2}, {2, 0, 2, 0, 1, 0});
  // True exactly where the input element equals 2.
  Tensor expected =
      bool_factory.make({3, 2}, {true, false, true, false, false, false});

  Scalar rhs = 2;

  Tensor result = bool_factory.zeros(
      {3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_eq_scalar_out(input, rhs, result);

  EXPECT_TENSOR_EQ(result, expected);
}
140 
// Output is DYNAMIC_BOUND and created as {10, 10}, larger than the {3, 2}
// input; the final comparison is against a {3, 2} expected tensor.
TEST_F(OpEqScalarOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  /* %python
  out_args = "{10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND"
  %rewrite(unary_op_out_dtype) */

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({3, 2}, {2, 0, 2, 0, 1, 0});
  // True exactly where the input element equals 2.
  Tensor expected =
      bool_factory.make({3, 2}, {true, false, true, false, false, false});

  Scalar rhs = 2;

  Tensor result = bool_factory.zeros(
      {10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  op_eq_scalar_out(input, rhs, result);

  EXPECT_TENSOR_EQ(result, expected);
}
160 
// Output is DYNAMIC_UNBOUND and created as {1, 1}, smaller than the {3, 2}
// input. Skipped unless the build reports output_resize support.
TEST_F(OpEqScalarOutTest, DynamicShapeUnbound) {
  const auto can_resize_output =
      torch::executor::testing::SupportedFeatures::get()->output_resize;
  if (!can_resize_output) {
    GTEST_SKIP() << "Dynamic shape unbound not supported";
  }
  /* %python
  out_args = "{1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND"
  %rewrite(unary_op_out_dtype) */

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.make({3, 2}, {2, 0, 2, 0, 1, 0});
  // True exactly where the input element equals 2.
  Tensor expected =
      bool_factory.make({3, 2}, {true, false, true, false, false, false});

  Scalar rhs = 2;

  Tensor result = bool_factory.zeros(
      {1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
  op_eq_scalar_out(input, rhs, result);

  EXPECT_TENSOR_EQ(result, expected);
}
183