xref: /aosp_15_r20/external/executorch/kernels/test/op_mm_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 #include <executorch/runtime/platform/runtime.h>
17 
18 #include <gtest/gtest.h>
19 #include <limits>
20 
21 using namespace ::testing;
22 using exec_aten::ArrayRef;
23 using exec_aten::Scalar;
24 using exec_aten::ScalarType;
25 using exec_aten::Tensor;
26 using torch::executor::testing::TensorFactory;
27 
28 class OpMmOutTest : public OperatorTest {
29  protected:
op_mm_out(const Tensor & self,const Tensor & mat2,Tensor & out)30   Tensor& op_mm_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
31     return torch::executor::aten::mm_outf(context_, self, mat2, out);
32   }
33 
34   template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()35   void test_dtype() {
36     TensorFactory<DTYPE> tf;
37 
38     if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
39       if (DTYPE == ScalarType::Half) {
40         GTEST_SKIP()
41             << "skip Half because torch::executor::aten::mm_out does not support Half";
42         return;
43       }
44     }
45 
46     // matmul gives 4 * 2 * 3 = 24
47     Tensor x = tf.full({3, 4}, 2);
48     Tensor y = tf.full({4, 5}, 3);
49 
50     // Output shape should be (3, 5)
51     Tensor out = tf.zeros({3, 5});
52 
53     op_mm_out(x, y, out);
54 
55     Tensor expected = tf.full({3, 5}, 24);
56 
57     EXPECT_TENSOR_EQ(out, expected);
58   }
59 };
60 
TEST_F(OpMmOutTest, OutputDim) {
  TensorFactory<ScalarType::Int> tf;

  // Compatible shapes: x is (3, 4) and y is (4, 5), so out must be (3, 5).
  Tensor x = tf.ones({3, 4});
  Tensor y = tf.ones({4, 5});
  Tensor out = tf.zeros({3, 5});

  Tensor& ret = op_mm_out(x, y, out);

  // The op must return a reference to the provided out tensor itself.
  // Comparing values alone would also pass if a copy were returned.
  EXPECT_EQ(&ret, &out);

  // Each element is a dot product of four 1 * 1 terms, so out is filled
  // with 4.
  Tensor expected = tf.full({3, 5}, 4);

  EXPECT_TENSOR_EQ(out, expected);
}
79 
80 /// A generic smoke test that works for any dtype that supports ones() and
81 /// zeros().
TEST_F(OpMmOutTest,AllDtypesSupported)82 TEST_F(OpMmOutTest, AllDtypesSupported) {
83 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
84   ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
85 #undef TEST_ENTRY
86   // TODO: Also add tests for half, complex, quantized, and other types. Easiest
87   // way to do that would be to make TensorFactory support zeros() and ones()
88   // for those types.
89 }
90 
TEST_F(OpMmOutTest, EmptyInputWithEmptyOutTensorPasses) {
  TensorFactory<ScalarType::Float> tf_float;

  // (0, 3) @ (3, 0) produces a (0, 0) result, so none of these tensors
  // holds any elements.
  Tensor lhs = tf_float.make({0, 3}, {});
  Tensor rhs = tf_float.make({3, 0}, {});
  Tensor result = tf_float.make({0, 0}, {});
  Tensor want = tf_float.make({0, 0}, {});

  Tensor& got = op_mm_out(lhs, rhs, result);
  EXPECT_TENSOR_EQ(got, want);
}
105 
TEST_F(OpMmOutTest, InfinityTensorPasses) {
  TensorFactory<ScalarType::Float> tf_float;
  constexpr float kInf = std::numeric_limits<float>::infinity();

  // Every product is inf * 3 = +inf, and a sum of +inf terms stays +inf.
  Tensor lhs = tf_float.full({3, 4}, kInf);
  Tensor rhs = tf_float.full({4, 5}, 3);

  // (3, 4) @ (4, 5) -> (3, 5).
  Tensor result = tf_float.zeros({3, 5});
  Tensor want = tf_float.full({3, 5}, kInf);

  EXPECT_TENSOR_EQ(op_mm_out(lhs, rhs, result), want);
}
119 
TEST_F(OpMmOutTest, MismatchedDimensionsDies) {
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.full({2, 2}, 3);

  // x has 2 columns but wrong_y has 3 rows, so the inner dimensions do not
  // match; right_y has 2 rows and is compatible.
  Tensor wrong_y = tf.full({3, 1}, 1);
  Tensor right_y = tf.full({2, 2}, 1);

  // A correctly shaped (2, 2) out tensor, pre-filled with zeros.
  Tensor out = tf.full({2, 2}, 0);

  // Each element of x @ right_y is 3 * 1 + 3 * 1 = 6.
  Tensor expected = tf.full({2, 2}, 6);
  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(x, wrong_y, out));

  // The same out tensor is then reused for a valid multiplication.
  EXPECT_TENSOR_EQ(op_mm_out(x, right_y, out), expected);
}
136 
TEST_F(OpMmOutTest, MismatchedDimensionSizeDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched dimension size";
  }
  TensorFactory<ScalarType::Int> tf_int;
  Tensor lhs = tf_int.full({2, 2}, 3);

  // bad_rhs is rank 3, which this test expects the kernel to reject.
  Tensor bad_rhs = tf_int.full({2, 2, 2}, 1);
  Tensor good_rhs = tf_int.full({2, 2}, 1);

  // bad_out is rank 3 as well, which the test also expects to be rejected.
  Tensor good_out = tf_int.ones({2, 2});
  Tensor bad_out = tf_int.ones({2, 2, 3});

  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(lhs, good_rhs, bad_out));
  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(lhs, bad_rhs, good_out));
}
155 
TEST_F(OpMmOutTest, WrongOutShapeDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle wrong out shape";
  }
  TensorFactory<ScalarType::Int> tf_int;

  // (10, 3) @ (3, 4) -> (10, 4); each element sums three 1 * 1 products.
  Tensor lhs = tf_int.ones({10, 3});
  Tensor rhs = tf_int.ones({3, 4});

  Tensor good_out = tf_int.ones({10, 4});
  // (7, 5) does not match the (10, 4) result shape.
  Tensor bad_out = tf_int.ones({7, 5});

  ET_EXPECT_KERNEL_FAILURE(context_, op_mm_out(lhs, rhs, bad_out));

  Tensor want = tf_int.full({10, 4}, 3);
  EXPECT_TENSOR_EQ(op_mm_out(lhs, rhs, good_out), want);
}
173 
TEST_F(OpMmOutTest, DynamicShapeUpperBoundSameAsExpected) {
  TensorFactory<ScalarType::Float> tf;

  // (3, 2) @ (2, 4) -> (3, 4). The expected values are presumably taken
  // from a reference implementation; e.g. the first element is
  // 0.1741 * 0.8072 + 0.3479 * 0.3164 ~= 0.2506.
  Tensor x = tf.make(
      {3, 2},
      {0.17412060499191284,
       0.34793388843536377,
       0.8187907934188843,
       0.9979893565177917,
       0.7049332857131958,
       0.4255824089050293});
  Tensor y = tf.make(
      {2, 4},
      {0.8071839213371277,
       0.13667285442352295,
       0.9002121090888977,
       0.9070476293563843,
       0.31638312339782715,
       0.3691965937614441,
       0.09420186281204224,
       0.9310881495475769});
  Tensor expected_result = tf.make(
      {3, 4},
      {0.2506277561187744,
       0.15225356817245483,
       0.18952149152755737,
       0.48189279437065125,
       0.976661741733551,
       0.480360746383667,
       0.8310978412628174,
       1.6718982458114624,
       0.703657865524292,
       0.2534688115119934,
       0.6746801733970642,
       1.0356627702713013});

  // out is allocated with exactly the result shape, but marked DYNAMIC_BOUND.
  Tensor out =
      tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
  Tensor& ret = op_mm_out(x, y, out);

  // Check both the returned tensor and the out argument itself.
  EXPECT_TENSOR_CLOSE(ret, expected_result);
  EXPECT_TENSOR_CLOSE(out, expected_result);
}
215 
TEST_F(OpMmOutTest,DynamicShapeUpperBoundLargerThanExpected)216 TEST_F(OpMmOutTest, DynamicShapeUpperBoundLargerThanExpected) {
217   TensorFactory<ScalarType::Float> tf;
218 
219   Tensor x = tf.make(
220       {3, 2},
221       {0.17412060499191284,
222        0.34793388843536377,
223        0.8187907934188843,
224        0.9979893565177917,
225        0.7049332857131958,
226        0.4255824089050293});
227   Tensor y = tf.make(
228       {2, 4},
229       {0.8071839213371277,
230        0.13667285442352295,
231        0.9002121090888977,
232        0.9070476293563843,
233        0.31638312339782715,
234        0.3691965937614441,
235        0.09420186281204224,
236        0.9310881495475769});
237   Tensor expected_result = tf.make(
238       {3, 4},
239       {0.2506277561187744,
240        0.15225356817245483,
241        0.18952149152755737,
242        0.48189279437065125,
243        0.976661741733551,
244        0.480360746383667,
245        0.8310978412628174,
246        1.6718982458114624,
247        0.703657865524292,
248        0.2534688115119934,
249        0.6746801733970642,
250        1.0356627702713013});
251 
252   Tensor out =
253       tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
254   Tensor ret = op_mm_out(x, y, out);
255   EXPECT_TENSOR_CLOSE(out, expected_result);
256 }
257 
TEST_F(OpMmOutTest,DynamicShapeUnbound)258 TEST_F(OpMmOutTest, DynamicShapeUnbound) {
259   GTEST_SKIP() << "Dynamic shape not supported";
260   TensorFactory<ScalarType::Float> tf;
261 
262   Tensor x = tf.make(
263       {3, 2},
264       {0.17412060499191284,
265        0.34793388843536377,
266        0.8187907934188843,
267        0.9979893565177917,
268        0.7049332857131958,
269        0.4255824089050293});
270   Tensor y = tf.make(
271       {2, 4},
272       {0.8071839213371277,
273        0.13667285442352295,
274        0.9002121090888977,
275        0.9070476293563843,
276        0.31638312339782715,
277        0.3691965937614441,
278        0.09420186281204224,
279        0.9310881495475769});
280   Tensor expected_result = tf.make(
281       {3, 4},
282       {0.2506277561187744,
283        0.15225356817245483,
284        0.18952149152755737,
285        0.48189279437065125,
286        0.976661741733551,
287        0.480360746383667,
288        0.8310978412628174,
289        1.6718982458114624,
290        0.703657865524292,
291        0.2534688115119934,
292        0.6746801733970642,
293        1.0356627702713013});
294 
295   Tensor out =
296       tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
297   Tensor ret = op_mm_out(x, y, out);
298   EXPECT_TENSOR_CLOSE(out, expected_result);
299 }
300