1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <gtest/gtest.h>
16
17 using namespace ::testing;
18 using exec_aten::Scalar;
19 using exec_aten::ScalarType;
20 using exec_aten::Tensor;
21 using executorch::runtime::KernelRuntimeContext;
22 using torch::executor::testing::TensorFactory;
23
24 class OpLtScalarOutTest : public OperatorTest {
25 protected:
op_lt_scalar_out(const Tensor & self,Scalar & other,Tensor & out)26 Tensor& op_lt_scalar_out(const Tensor& self, Scalar& other, Tensor& out) {
27 return torch::executor::aten::lt_outf(context_, self, other, out);
28 }
29
30 template <ScalarType DTYPE_IN, ScalarType DTYPE_OUT>
test_lt_scalar_out()31 void test_lt_scalar_out() {
32 TensorFactory<DTYPE_IN> tf;
33 TensorFactory<DTYPE_OUT> tf_out;
34
35 const std::vector<int32_t> sizes = {2, 2};
36 Tensor out = tf_out.ones(sizes);
37 Scalar other = 2;
38
39 // Valid input should give the expected output
40 op_lt_scalar_out(tf.make(sizes, /*data=*/{3, 1, 2, 4}), other, out);
41 EXPECT_TENSOR_EQ(
42 out, tf_out.make(sizes, /*data=*/{false, true, false, false}));
43 }
44 };
45
46 class OpLtTensorOutTest : public OperatorTest {
47 protected:
48 Tensor&
op_lt_tensor_out(const Tensor & self,const Tensor & other,Tensor & out)49 op_lt_tensor_out(const Tensor& self, const Tensor& other, Tensor& out) {
50 return torch::executor::aten::lt_outf(context_, self, other, out);
51 }
52
53 template <ScalarType DTYPE_IN, ScalarType DTYPE_OUT>
test_dtype()54 void test_dtype() {
55 TensorFactory<DTYPE_IN> tf_input;
56 TensorFactory<DTYPE_OUT> tf_out;
57 Tensor a = tf_input.make(/*sizes=*/{2, 2}, /*data=*/{2, 3, 2, 4});
58 Tensor b = tf_input.make({2, 2}, {1, 4, 2, 3});
59 Tensor out = tf_out.zeros({2, 2});
60
61 op_lt_tensor_out(a, b, out);
62 EXPECT_TENSOR_EQ(out, tf_out.make({2, 2}, {false, true, false, false}));
63 }
64 };
65
// Runs test_lt_scalar_out for every (real input, real output) dtype pair, and
// additionally with a Bool output for each real input dtype.
TEST_F(OpLtScalarOutTest, AllRealInputBoolOutputSupport) {
// One test_lt_scalar_out<in, out>() call for a single dtype pair.
#define LT_SCALAR_CASE(ctype_in, dtype_in, ctype_out, dtype_out) \
  test_lt_scalar_out<ScalarType::dtype_in, ScalarType::dtype_out>();

// For one input dtype: every real output dtype first, then a Bool output.
#define LT_SCALAR_CASES_FOR_INPUT(ctype_in, dtype_in)            \
  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, LT_SCALAR_CASE) \
  test_lt_scalar_out<ScalarType::dtype_in, ScalarType::Bool>();

  ET_FORALL_REAL_TYPES(LT_SCALAR_CASES_FOR_INPUT)

#undef LT_SCALAR_CASES_FOR_INPUT
#undef LT_SCALAR_CASE
}
79
// A Bool input compared against a fractional Scalar: per the expected values
// below, `false < 0.5` yields true and `true < 0.5` yields false.
TEST_F(OpLtScalarOutTest, BoolInputDtype) {
  TensorFactory<ScalarType::Bool> tf_bool;

  const std::vector<int32_t> shape = {2, 2};
  Tensor input = tf_bool.make(shape, /*data=*/{false, true, false, true});
  Tensor result = tf_bool.zeros(shape);
  Scalar threshold = 0.5;

  op_lt_scalar_out(input, threshold, result);

  Tensor expected = tf_bool.make(shape, /*data=*/{true, false, true, false});
  EXPECT_TENSOR_EQ(result, expected);
}
92
93 // Mismatched shape tests.
// A {4} input paired with a {2, 2} output must make the kernel fail. Skipped
// in ATen builds, whose kernel tolerates the mismatch.
TEST_F(OpLtScalarOutTest, MismatchedInOutShapesDies) {
  const bool is_aten_build =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (is_aten_build) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor input = tf_int.ones(/*sizes=*/{4});
  Tensor result = tf_bool.ones(/*sizes=*/{2, 2});
  Scalar threshold = 3;

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_lt_scalar_out(input, threshold, result));
}
107
// `out` starts as a DYNAMIC_BOUND {4, 1} Int tensor. The expected result is
// {2, 2}, so the kernel must resize `out` to the input's shape.
TEST_F(OpLtScalarOutTest, DynamicOutShapeTest) {
  TensorFactory<ScalarType::Int> tf;

  const std::vector<int32_t> in_shape = {2, 2};
  const std::vector<int32_t> initial_out_shape = {4, 1};

  Tensor out = tf.zeros(
      initial_out_shape, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Scalar other = 2;

  // Valid input should give the expected output
  Tensor input = tf.make(in_shape, /*data=*/{3, 1, 2, 4});
  op_lt_scalar_out(input, other, out);

  Tensor expected = tf.make(in_shape, /*data=*/{false, true, false, false});
  EXPECT_TENSOR_EQ(out, expected);
}
122
// Runs test_dtype for every (real input, real output) dtype pair, and
// additionally with a Bool output for each real input dtype.
TEST_F(OpLtTensorOutTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
  test_dtype<ScalarType::dtype_in, ScalarType::dtype_out>();

#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)            \
  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
  test_dtype<ScalarType::dtype_in, ScalarType::Bool>();

  // Every expanded call already ends in ';', so no trailing semicolon here
  // (an extra one is an empty statement flagged by -Wextra-semi-stmt).
  ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)

#undef TEST_FORALL_OUT_TYPES
#undef TEST_ENTRY
}
136
// Inputs of shape {4} and {2, 2} are incompatible (trailing dims 4 vs 2), so
// the kernel must fail. Skipped in ATen builds, whose kernel tolerates it.
TEST_F(OpLtTensorOutTest, MismatchedInShapesDies) {
  const bool is_aten_build =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (is_aten_build) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor lhs = tf_int.ones(/*sizes=*/{4});
  Tensor rhs = tf_int.ones(/*sizes=*/{2, 2});
  Tensor result = tf_bool.ones(/*sizes=*/{4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_tensor_out(lhs, rhs, result));
}
150
// Two {4} inputs paired with a {2, 2} output must make the kernel fail.
// Skipped in ATen builds, whose kernel tolerates the mismatch.
TEST_F(OpLtTensorOutTest, MismatchedInOutShapesDies) {
  const bool is_aten_build =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (is_aten_build) {
    GTEST_SKIP() << "ATen kernel can handle mismatched shapes";
  }

  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Bool> tf_bool;

  Tensor lhs = tf_int.ones(/*sizes=*/{4});
  Tensor rhs = tf_int.ones(/*sizes=*/{4});
  Tensor result = tf_bool.ones(/*sizes=*/{2, 2});

  ET_EXPECT_KERNEL_FAILURE(context_, op_lt_tensor_out(lhs, rhs, result));
}
164
// `out` starts as a DYNAMIC_BOUND {1, 4} Int tensor. The expected result is
// {2, 2}, so the kernel must resize `out` to the inputs' shape.
TEST_F(OpLtTensorOutTest, DynamicOutShapeTest) {
  TensorFactory<ScalarType::Int> tf;

  const std::vector<int32_t> shape = {2, 2};
  Tensor lhs = tf.make(shape, /*data=*/{2, 3, 2, 4});
  Tensor rhs = tf.make(shape, /*data=*/{1, 4, 2, 3});

  Tensor result =
      tf.zeros({1, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);

  op_lt_tensor_out(lhs, rhs, result);

  Tensor expected = tf.make(shape, /*data=*/{false, true, false, false});
  EXPECT_TENSOR_EQ(result, expected);
}
177