1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 #include <executorch/runtime/platform/runtime.h>
17
18 #include <gtest/gtest.h>
19 #include <limits>
20
21 using namespace ::testing;
22 using exec_aten::ArrayRef;
23 using exec_aten::Scalar;
24 using exec_aten::ScalarType;
25 using exec_aten::Tensor;
26 using torch::executor::testing::TensorFactory;
27
28 class OpMmOutTest : public OperatorTest {
29 protected:
op_mm_out(const Tensor & self,const Tensor & mat2,Tensor & out)30 Tensor& op_mm_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
31 return torch::executor::aten::mm_outf(context_, self, mat2, out);
32 }
33
34 template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()35 void test_dtype() {
36 TensorFactory<DTYPE> tf;
37
38 if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
39 if (DTYPE == ScalarType::Half) {
40 GTEST_SKIP()
41 << "skip Half because torch::executor::aten::mm_out does not support Half";
42 return;
43 }
44 }
45
46 // matmul gives 4 * 2 * 3 = 24
47 Tensor x = tf.full({3, 4}, 2);
48 Tensor y = tf.full({4, 5}, 3);
49
50 // Output shape should be (3, 5)
51 Tensor out = tf.zeros({3, 5});
52
53 op_mm_out(x, y, out);
54
55 Tensor expected = tf.full({3, 5}, 24);
56
57 EXPECT_TENSOR_EQ(out, expected);
58 }
59 };
60
/// Checks that mm.out fills a correctly shaped out tensor and returns it.
TEST_F(OpMmOutTest, OutputDim) {
  TensorFactory<ScalarType::Int> tf_int;

  // (3, 4) x (4, 5) -> (3, 5): the three shapes are mutually compatible.
  Tensor lhs = tf_int.ones({3, 4});
  Tensor rhs = tf_int.ones({4, 5});
  Tensor out = tf_int.zeros({3, 5});

  // The operator must hand back the out tensor it was given.
  Tensor returned = op_mm_out(lhs, rhs, out);
  EXPECT_TENSOR_EQ(returned, out);

  // Every entry is a sum of four 1 * 1 products.
  EXPECT_TENSOR_EQ(out, tf_int.full({3, 5}, 4));
}
79
80 /// A generic smoke test that works for any dtype that supports ones() and
81 /// zeros().
TEST_F(OpMmOutTest,AllDtypesSupported)82 TEST_F(OpMmOutTest, AllDtypesSupported) {
83 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
84 ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
85 #undef TEST_ENTRY
86 // TODO: Also add tests for half, complex, quantized, and other types. Easiest
87 // way to do that would be to make TensorFactory support zeros() and ones()
88 // for those types.
89 }
90
/// Checks that multiplying empty matrices into an empty out tensor succeeds.
TEST_F(OpMmOutTest, EmptyInputWithEmptyOutTensorPasses) {
  TensorFactory<ScalarType::Float> tf_float;

  // (0, 3) x (3, 0) -> (0, 0): the product has no elements.
  Tensor lhs = tf_float.make({0, 3}, {});
  Tensor rhs = tf_float.make({3, 0}, {});
  Tensor out = tf_float.make({0, 0}, {});

  Tensor expected = tf_float.make({0, 0}, {});
  EXPECT_TENSOR_EQ(op_mm_out(lhs, rhs, out), expected);
}
105
/// Checks that +infinity propagates through the matrix product.
TEST_F(OpMmOutTest, InfinityTensorPasses) {
  TensorFactory<ScalarType::Float> tf_float;
  const float inf = std::numeric_limits<float>::infinity();

  // Every dot product sums four (inf * 3) terms, which remains +inf.
  Tensor lhs = tf_float.full({3, 4}, inf);
  Tensor rhs = tf_float.full({4, 5}, 3);
  Tensor out = tf_float.zeros({3, 5}); // (3, 4) x (4, 5) -> (3, 5)

  Tensor expected = tf_float.full({3, 5}, inf);
  EXPECT_TENSOR_EQ(op_mm_out(lhs, rhs, out), expected);
}
119
/// Checks that an incompatible inner dimension makes the kernel fail. It then
/// checks that a valid call with the same out tensor succeeds.
TEST_F(OpMmOutTest, MismatchedDimensionsDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.full({2, 2}, 3);

  // x has 2 columns, so a (3, 1) right operand has a mismatched inner
  // dimension, while a (2, 2) right operand is valid.
  Tensor wrong_y = tf.full({3, 1}, 1);
  Tensor right_y = tf.full({2, 2}, 1);

  // Correctly shaped (2, 2) out tensor, initialized to zeros.
  Tensor out = tf.full({2, 2}, 0);

  // Each entry of x @ right_y is a sum of two 3 * 1 products.
  Tensor expected = tf.full({2, 2}, 6);
  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(x, wrong_y, out));

  // The same out tensor is then reused for a valid multiplication.
  EXPECT_TENSOR_EQ(op_mm_out(x, right_y, out), expected);
}
136
/// Checks that a non-2-D operand or out tensor makes the kernel fail
/// (portable kernels only).
TEST_F(OpMmOutTest, MismatchedDimensionSizeDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched dimension size";
  }
  TensorFactory<ScalarType::Int> tf_int;

  // The valid operand and out tensor are 2-D; each "bad" tensor is 3-D.
  Tensor lhs = tf_int.full({2, 2}, 3);
  Tensor good_rhs = tf_int.full({2, 2}, 1);
  Tensor bad_rhs = tf_int.full({2, 2, 2}, 1);
  Tensor good_out = tf_int.ones({2, 2});
  Tensor bad_out = tf_int.ones({2, 2, 3});

  // Rank-3 out tensor.
  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(lhs, good_rhs, bad_out));
  // Rank-3 right operand.
  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(lhs, bad_rhs, good_out));
}
155
/// Checks that an out tensor with the wrong shape makes the kernel fail, and
/// that a correctly shaped one receives the product (portable kernels only).
TEST_F(OpMmOutTest, WrongOutShapeDies) {
  const bool aten_mode =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (aten_mode) {
    GTEST_SKIP() << "ATen kernel can handle wrong out shape";
  }
  TensorFactory<ScalarType::Int> tf_int;

  // (10, 3) x (3, 4) -> (10, 4).
  Tensor lhs = tf_int.ones({10, 3});
  Tensor rhs = tf_int.ones({3, 4});

  Tensor good_out = tf_int.ones({10, 4});
  Tensor bad_out = tf_int.ones({7, 5}); // neither dimension matches (10, 4)

  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(lhs, rhs, bad_out));

  // Every entry is a sum of three 1 * 1 products.
  EXPECT_TENSOR_EQ(op_mm_out(lhs, rhs, good_out), tf_int.full({10, 4}, 3));
}
173
/// Checks mm.out with a DYNAMIC_BOUND out tensor whose initial shape already
/// equals the (3, 4) result shape.
TEST_F(OpMmOutTest, DynamicShapeUpperBoundSameAsExpected) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284,
       0.34793388843536377,
       0.8187907934188843,
       0.9979893565177917,
       0.7049332857131958,
       0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277,
       0.13667285442352295,
       0.9002121090888977,
       0.9070476293563843,
       0.31638312339782715,
       0.3691965937614441,
       0.09420186281204224,
       0.9310881495475769});
  // Precomputed x @ y, shape (3, 4).
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744,
       0.15225356817245483,
       0.18952149152755737,
       0.48189279437065125,
       0.976661741733551,
       0.480360746383667,
       0.8310978412628174,
       1.6718982458114624,
       0.703657865524292,
       0.2534688115119934,
       0.6746801733970642,
       1.0356627702713013});

  Tensor out =
      tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor ret = op_mm_out(x, y, out);

  // Both the returned tensor and `out` must hold the product.
  EXPECT_TENSOR_CLOSE(ret, expected_result);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
215
/// Checks mm.out with a DYNAMIC_BOUND out tensor that starts at (10, 10),
/// larger than the (3, 4) result shape.
TEST_F(OpMmOutTest, DynamicShapeUpperBoundLargerThanExpected) {
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284,
       0.34793388843536377,
       0.8187907934188843,
       0.9979893565177917,
       0.7049332857131958,
       0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277,
       0.13667285442352295,
       0.9002121090888977,
       0.9070476293563843,
       0.31638312339782715,
       0.3691965937614441,
       0.09420186281204224,
       0.9310881495475769});
  // Precomputed x @ y, shape (3, 4).
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744,
       0.15225356817245483,
       0.18952149152755737,
       0.48189279437065125,
       0.976661741733551,
       0.480360746383667,
       0.8310978412628174,
       1.6718982458114624,
       0.703657865524292,
       0.2534688115119934,
       0.6746801733970642,
       1.0356627702713013});

  Tensor out =
      tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor ret = op_mm_out(x, y, out);

  // Both the returned tensor and `out` must hold the (3, 4) product.
  EXPECT_TENSOR_CLOSE(ret, expected_result);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
257
/// Intended to check mm.out with a DYNAMIC_UNBOUND out tensor starting at
/// (1, 1). The skip below is unconditional, so the rest of the body does not
/// run at present.
TEST_F(OpMmOutTest, DynamicShapeUnbound) {
  GTEST_SKIP() << "Dynamic shape not supported";
  TensorFactory<ScalarType::Float> tf;

  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284,
       0.34793388843536377,
       0.8187907934188843,
       0.9979893565177917,
       0.7049332857131958,
       0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277,
       0.13667285442352295,
       0.9002121090888977,
       0.9070476293563843,
       0.31638312339782715,
       0.3691965937614441,
       0.09420186281204224,
       0.9310881495475769});
  // Precomputed x @ y, shape (3, 4).
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744,
       0.15225356817245483,
       0.18952149152755737,
       0.48189279437065125,
       0.976661741733551,
       0.480360746383667,
       0.8310978412628174,
       1.6718982458114624,
       0.703657865524292,
       0.2534688115119934,
       0.6746801733970642,
       1.0356627702713013});

  Tensor out =
      tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
  Tensor ret = op_mm_out(x, y, out);

  // Both the returned tensor and `out` must hold the (3, 4) product.
  EXPECT_TENSOR_CLOSE(ret, expected_result);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
300