xref: /aosp_15_r20/external/executorch/kernels/test/op_lt_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <gtest/gtest.h>
16 
17 using namespace ::testing;
18 using exec_aten::Scalar;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using executorch::runtime::KernelRuntimeContext;
22 using torch::executor::testing::TensorFactory;
23 
24 class OpLtScalarOutTest : public OperatorTest {
25  protected:
op_lt_scalar_out(const Tensor & self,Scalar & other,Tensor & out)26   Tensor& op_lt_scalar_out(const Tensor& self, Scalar& other, Tensor& out) {
27     return torch::executor::aten::lt_outf(context_, self, other, out);
28   }
29 
30   template <ScalarType DTYPE_IN, ScalarType DTYPE_OUT>
test_lt_scalar_out()31   void test_lt_scalar_out() {
32     TensorFactory<DTYPE_IN> tf;
33     TensorFactory<DTYPE_OUT> tf_out;
34 
35     const std::vector<int32_t> sizes = {2, 2};
36     Tensor out = tf_out.ones(sizes);
37     Scalar other = 2;
38 
39     // Valid input should give the expected output
40     op_lt_scalar_out(tf.make(sizes, /*data=*/{3, 1, 2, 4}), other, out);
41     EXPECT_TENSOR_EQ(
42         out, tf_out.make(sizes, /*data=*/{false, true, false, false}));
43   }
44 };
45 
46 class OpLtTensorOutTest : public OperatorTest {
47  protected:
48   Tensor&
op_lt_tensor_out(const Tensor & self,const Tensor & other,Tensor & out)49   op_lt_tensor_out(const Tensor& self, const Tensor& other, Tensor& out) {
50     return torch::executor::aten::lt_outf(context_, self, other, out);
51   }
52 
53   template <ScalarType DTYPE_IN, ScalarType DTYPE_OUT>
test_dtype()54   void test_dtype() {
55     TensorFactory<DTYPE_IN> tf_input;
56     TensorFactory<DTYPE_OUT> tf_out;
57     Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 3, 2, 4});
58     Tensor b = tf_input.make({2, 2}, {1, 4, 2, 3});
59     Tensor out = tf_out.zeros({2, 2});
60 
61     op_lt_tensor_out(a, b, out);
62     EXPECT_TENSOR_EQ(out, tf_out.make({2, 2}, {false, true, false, false}));
63   }
64 };
65 
// For every real input dtype, runs the scalar comparison once per output
// dtype enumerated by ET_FORALL_REAL_TYPES_WITH2, then once more with a Bool
// output.
TEST_F(OpLtScalarOutTest, AllRealInputBoolOutputSupport) {
// Callback invoked for each (input dtype, output dtype) pair.
#define LT_SCALAR_CASE(ctype_in, dtype_in, ctype_out, dtype_out) \
  test_lt_scalar_out<ScalarType::dtype_in, ScalarType::dtype_out>();

// Callback invoked for each input dtype: all real outputs, then Bool.
#define LT_SCALAR_ALL_OUTS(ctype_in, dtype_in)                     \
  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, LT_SCALAR_CASE)   \
  test_lt_scalar_out<ScalarType::dtype_in, ScalarType::Bool>();

  ET_FORALL_REAL_TYPES(LT_SCALAR_ALL_OUTS)

#undef LT_SCALAR_ALL_OUTS
#undef LT_SCALAR_CASE
}
79 
// Bool input compared against a fractional scalar: false (0) < 0.5 holds,
// while true (1) < 0.5 does not.
TEST_F(OpLtScalarOutTest, BoolInputDtype) {
  TensorFactory<ScalarType::Bool> bool_factory;
  const std::vector<int32_t> shape = {2, 2};

  Tensor input = bool_factory.make(shape, /*data=*/{false, true, false, true});
  Tensor result = bool_factory.zeros(shape);
  Scalar threshold = 0.5;

  op_lt_scalar_out(input, threshold, result);

  Tensor expected =
      bool_factory.make(shape, /*data=*/{true, false, true, false});
  EXPECT_TENSOR_EQ(result, expected);
}
92 
93 // Mismatched shape tests.
// A {4} input with a {2, 2} output is expected to make the kernel fail.
// Skipped under ATen, whose kernel handles mismatched shapes.
TEST_F(OpLtScalarOutTest, MismatchedInOutShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor input = int_factory.ones(/*sizes=*/{4});
  Tensor result = bool_factory.ones(/*sizes=*/{2, 2});
  Scalar rhs = 3;

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_scalar_out(input, rhs, result));
}
107 
// The output is allocated as {4, 1} with DYNAMIC_BOUND dynamism. Comparing it
// against a {2, 2} tensor afterwards checks that the kernel reshaped it to
// the input's shape. An Int output dtype is used here.
TEST_F(OpLtScalarOutTest, DynamicOutShapeTest) {
  TensorFactory<ScalarType::Int> tf;

  const std::vector<int32_t> in_shape = {2, 2};
  const std::vector<int32_t> alloc_shape = {4, 1};

  Tensor result = tf.zeros(
      alloc_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Scalar pivot = 2;
  Tensor input = tf.make(in_shape, /*data=*/{3, 1, 2, 4});

  // Valid input should give the expected output
  op_lt_scalar_out(input, pivot, result);

  Tensor expected = tf.make(in_shape, /*data=*/{false, true, false, false});
  EXPECT_TENSOR_EQ(result, expected);
}
122 
// For every real input dtype, runs the tensor comparison once per output
// dtype enumerated by ET_FORALL_REAL_TYPES_WITH2, then once more with a Bool
// output.
TEST_F(OpLtTensorOutTest, AllDtypesSupported) {
// Callback invoked for each (input dtype, output dtype) pair.
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
  test_dtype<ScalarType::dtype_in, ScalarType::dtype_out>();

// Callback invoked for each input dtype: all real outputs, then Bool.
#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)            \
  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
  test_dtype<ScalarType::dtype_in, ScalarType::Bool>();

  // Every expansion already ends in `;`. A trailing semicolon here would only
  // add an empty statement, which clang reports under -Wextra-semi-stmt.
  ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)

#undef TEST_FORALL_OUT_TYPES
#undef TEST_ENTRY
}
136 
// Inputs of shape {4} and {2, 2} are not broadcast-compatible (trailing dims
// 4 vs 2), so the kernel is expected to fail. Skipped under ATen, whose
// kernel handles mismatched shapes.
TEST_F(OpLtTensorOutTest, MismatchedInShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor lhs = int_factory.ones(/*sizes=*/{4});
  Tensor rhs = int_factory.ones(/*sizes=*/{2, 2});
  Tensor result = bool_factory.ones(/*sizes=*/{4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_tensor_out(lhs, rhs, result));
}
150 
// Matching {4} inputs combined with a {2, 2} output are expected to make the
// kernel fail. Skipped under ATen, whose kernel handles mismatched shapes.
TEST_F(OpLtTensorOutTest, MismatchedInOutShapesDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> int_factory;
  TensorFactory<ScalarType::Bool> bool_factory;

  Tensor lhs = int_factory.ones(/*sizes=*/{4});
  Tensor rhs = int_factory.ones(/*sizes=*/{4});
  Tensor result = bool_factory.ones(/*sizes=*/{2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_tensor_out(lhs, rhs, result));
}
164 
// The output is allocated as {1, 4} with DYNAMIC_BOUND dynamism. Comparing it
// against a {2, 2} tensor afterwards checks that the kernel reshaped it to
// the inputs' shape. An Int output dtype is used here.
TEST_F(OpLtTensorOutTest, DynamicOutShapeTest) {
  TensorFactory<ScalarType::Int> tf;
  const std::vector<int32_t> in_shape = {2, 2};

  Tensor lhs = tf.make(in_shape, /*data=*/{2, 3, 2, 4});
  Tensor rhs = tf.make(in_shape, /*data=*/{1, 4, 2, 3});
  Tensor result =
      tf.zeros({1, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  op_lt_tensor_out(lhs, rhs, result);

  Tensor expected = tf.make(in_shape, /*data=*/{false, true, false, false});
  EXPECT_TENSOR_EQ(result, expected);
}
177