1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16
17 #include <gtest/gtest.h>
18
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24
25 class OpBmmOutTest : public OperatorTest {
26 protected:
op_bmm_out(const Tensor & self,const Tensor & mat2,Tensor & out)27 Tensor& op_bmm_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
28 return torch::executor::aten::bmm_outf(context_, self, mat2, out);
29 }
30
31 template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()32 void test_dtype() {
33 TensorFactory<DTYPE> tf;
34
35 // Gives 4 * 2 * 3 = 24, shape (10, 3, 5)
36 Tensor x = tf.full({10, 3, 4}, 2);
37 Tensor y = tf.full({10, 4, 5}, 3);
38
39 Tensor out = tf.zeros({10, 3, 5});
40 op_bmm_out(x, y, out);
41
42 Tensor expected = tf.full({10, 3, 5}, 24);
43
44 EXPECT_TENSOR_EQ(out, expected);
45 }
46 };
47
/// Multiplies two all-ones Int tensors of shapes (10, 3, 4) and (10, 4, 5) and
/// checks both the returned tensor and the contents of `out`.
TEST_F(OpBmmOutTest, OutputDim) {
  TensorFactory<ScalarType::Int> factory;

  // Inner dimensions agree (4 and 4), batch sizes agree (10 and 10).
  Tensor lhs = factory.ones({10, 3, 4});
  Tensor rhs = factory.ones({10, 4, 5});

  // Destination with the product's shape, (10, 3, 5).
  Tensor out = factory.zeros({10, 3, 5});

  Tensor returned = op_bmm_out(lhs, rhs, out);

  // The tensor handed back by the op must compare equal to `out`.
  EXPECT_TENSOR_EQ(returned, out);

  // Each element is a sum of four 1 * 1 products, hence 4.
  Tensor expected = factory.full({10, 3, 5}, 4);
  EXPECT_TENSOR_EQ(out, expected);
}
68
/// Checks a hand-written Float product: (2, 4, 5) x (2, 5, 3) -> (2, 4, 3).
TEST_F(OpBmmOutTest, OutputDimFloat) {
  TensorFactory<ScalarType::Float> factory;

  // Left operand: two batches of 4x5 matrices.
  // clang-format off
  Tensor lhs = factory.make(
      {2, 4, 5},
      {
        4., 3., 1., 1., 1.,
        3., 1., 4., 4., 2.,
        1., 1., 1., 3., 3.,
        4., 2., 2., 2., 3.,

        1., 3., 1., 4., 4.,
        1., 1., 2., 4., 3.,
        4., 3., 4., 1., 2.,
        1., 4., 4., 4., 4.,
      });
  // clang-format on

  // Right operand: two batches of 5x3 matrices.
  // clang-format off
  Tensor rhs = factory.make(
      {2, 5, 3},
      {
        4., 4., 4.,
        2., 3., 1.,
        1., 4., 4.,
        3., 1., 2.,
        1., 4., 3.,

        1., 4., 4.,
        4., 4., 4.,
        2., 1., 4.,
        1., 4., 3.,
        1., 4., 4.,
      });
  // clang-format on

  // Destination has the product's shape, (2, 4, 3).
  Tensor out = factory.zeros({2, 4, 3});

  Tensor returned = op_bmm_out(lhs, rhs, out);

  // The tensor handed back by the op must compare equal to `out`.
  EXPECT_TENSOR_EQ(returned, out);

  // Reference product, one 4x3 matrix per batch.
  // clang-format off
  Tensor expected = factory.make(
      {2, 4, 3},
      {
        27., 34., 28.,
        32., 43., 43.,
        19., 26., 24.,
        31., 44., 39.,

        23., 49., 48.,
        16., 38., 40.,
        27., 44., 55.,
        33., 56., 64.,
      });
  // clang-format on

  EXPECT_TENSOR_EQ(out, expected);
}
132
133 /// A generic smoke test that works for any dtype that supports ones() and
134 /// zeros().
TEST_F(OpBmmOutTest,AllDtypesSupported)135 TEST_F(OpBmmOutTest, AllDtypesSupported) {
136 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
137 ET_FORALL_REAL_TYPES(TEST_ENTRY);
138 #undef TEST_ENTRY
139 // TODO: Also add tests for half, complex, quantized, and other types. Easiest
140 // way to do that would be to make TensorFactory support zeros() and ones()
141 // for those types.
142 }
143
/// Runs the op with a zero-element second operand and a zero-element `out`,
/// and checks that `out` still has no elements afterwards.
TEST_F(OpBmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
  TensorFactory<ScalarType::Int> factory;

  // (2, 2, 2) times (2, 2, 0): the last dimension of the second operand is 0.
  Tensor lhs = factory.full({2, 2, 2}, 3);
  Tensor rhs = factory.make({2, 2, 0}, {});

  // The destination is empty too; verify that before calling the op.
  Tensor out = factory.make({2, 2, 0}, {});
  EXPECT_EQ(out.numel(), 0);

  op_bmm_out(lhs, rhs, out);

  // Still empty after the call.
  EXPECT_EQ(out.numel(), 0);
}
159
/// A second operand whose sizes do not line up with the first must be
/// rejected; a compatible one must then succeed using the same `out`.
TEST_F(OpBmmOutTest, MismatchedDimensionsDies) {
  TensorFactory<ScalarType::Int> factory;

  Tensor lhs = factory.ones({2, 10, 3});

  // (3, 7, 4) disagrees with lhs on both batch size (3 vs 2) and inner
  // dimension (7 vs 3); (2, 3, 4) agrees on both.
  Tensor bad_rhs = factory.ones({3, 7, 4});
  Tensor good_rhs = factory.ones({2, 3, 4});

  Tensor out = factory.ones({2, 10, 4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(lhs, bad_rhs, out));

  // With all-ones operands each element is a sum of three 1 * 1 products.
  EXPECT_TENSOR_EQ(
      op_bmm_out(lhs, good_rhs, out), factory.full({2, 10, 4}, 3));
}
175
/// Operands or outputs with the wrong number of dimensions (2-D instead of
/// 3-D) must make the kernel fail. Skipped when running against ATen.
TEST_F(OpBmmOutTest, MismatchedDimensionSizeDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched dimension size";
  }
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.ones({2, 10, 3});

  // wrong_y is 2-D, so its rank does not match x; right_y is a compatible
  // 3-D operand.
  Tensor wrong_y = tf.ones({7, 4});
  Tensor right_y = tf.ones({2, 3, 4});

  // right_out has the product's shape (2, 10, 4); wrong_out is 2-D.
  Tensor right_out = tf.ones({2, 10, 4});
  Tensor wrong_out = tf.ones({7, 5});

  // Valid inputs, wrong-rank output.
  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(x, right_y, wrong_out));
  // Wrong-rank second operand, valid output.
  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(x, wrong_y, right_out));
}
197
/// An `out` tensor whose shape does not match the product must make the kernel
/// fail; the correctly shaped one must then succeed. Skipped under ATen.
TEST_F(OpBmmOutTest, WrongOutShapeDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle wrong out shape";
  }
  TensorFactory<ScalarType::Int> factory;

  Tensor lhs = factory.ones({2, 10, 3});
  Tensor rhs = factory.ones({2, 3, 4});

  // The product has shape (2, 10, 4); (3, 7, 5) differs in every dimension.
  Tensor good_out = factory.ones({2, 10, 4});
  Tensor bad_out = factory.ones({3, 7, 5});

  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(lhs, rhs, bad_out));

  // With all-ones operands each element is a sum of three 1 * 1 products.
  EXPECT_TENSOR_EQ(
      op_bmm_out(lhs, rhs, good_out), factory.full({2, 10, 4}, 3));
}
216
TEST_F(OpBmmOutTest,DynamicShapeUpperBoundSameAsExpected)217 TEST_F(OpBmmOutTest, DynamicShapeUpperBoundSameAsExpected) {
218 TensorFactory<ScalarType::Float> tf;
219
220 auto x = tf.make(
221 {3, 3, 6},
222 {0.7231091856956482, 0.7423362731933594, 0.5262957811355591,
223 0.24365824460983276, 0.584592342376709, 0.033152639865875244,
224 0.13871687650680542, 0.242235004901886, 0.815468966960907,
225 0.793160617351532, 0.2782524824142456, 0.48195880651474,
226 0.8197803497314453, 0.9970665574073792, 0.6984410881996155,
227 0.5675464272499084, 0.8352431654930115, 0.2055988311767578,
228 0.593172013759613, 0.11234724521636963, 0.1534569263458252,
229 0.24170821905136108, 0.7262365221977234, 0.7010802030563354,
230 0.2038237452507019, 0.6510535478591919, 0.7744860053062439,
231 0.4368913173675537, 0.5190907716751099, 0.6158523559570312,
232 0.8101882934570312, 0.9800970554351807, 0.1146882176399231,
233 0.3167651295661926, 0.6965049505233765, 0.9142746925354004,
234 0.9351036548614502, 0.9411783814430237, 0.5995072722434998,
235 0.06520867347717285, 0.5459962487220764, 0.18719732761383057,
236 0.03402292728424072, 0.944246232509613, 0.8801798820495605,
237 0.0012360215187072754, 0.5935860276222229, 0.4157699942588806,
238 0.41771942377090454, 0.2711215615272522, 0.6922780871391296,
239 0.2038482427597046, 0.6832956671714783, 0.75285404920578});
240 auto y = tf.make(
241 {3, 6, 2},
242 {0.8579357862472534, 0.6869555711746216, 0.0051323771476745605,
243 0.17565155029296875, 0.7496575117111206, 0.6046506762504578,
244 0.1099579930305481, 0.21209025382995605, 0.9703746438026428,
245 0.8369089365005493, 0.28198742866516113, 0.3741576075553894,
246 0.023700952529907227, 0.49101293087005615, 0.12347054481506348,
247 0.11432164907455444, 0.4724501967430115, 0.5750725269317627,
248 0.2952348589897156, 0.7966887950897217, 0.19573044776916504,
249 0.9536850452423096, 0.8426499366760254, 0.07835853099822998,
250 0.3755578398704529, 0.5225613117218018, 0.572950541973114,
251 0.6185871362686157, 0.6962141394615173, 0.5299500823020935,
252 0.25603562593460083, 0.7365944981575012, 0.020375549793243408,
253 0.20364665985107422, 0.3748350739479065, 0.2564433217048645});
254 Tensor expected_result = tf.make(
255 {3, 3, 2},
256 {1.6221470832824707,
257 1.498693823814392,
258 1.224705696105957,
259 1.2123372554779053,
260 2.1629090309143066,
261 2.05692195892334,
262 0.9047035574913025,
263 1.3324503898620605,
264 1.2006582021713257,
265 1.5112680196762085,
266 1.1946606636047363,
267 1.5640640258789062,
268 1.405808448791504,
269 1.5957869291305542,
270 1.3348338603973389,
271 1.2967426776885986,
272 1.1425018310546875,
273 1.2352378368377686});
274
275 Tensor out =
276 tf.zeros({3, 3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
277 Tensor ret = op_bmm_out(x, y, out);
278 EXPECT_TENSOR_CLOSE(out, expected_result);
279 }
280
TEST_F(OpBmmOutTest,DynamicShapeUpperBoundLargerThanExpected)281 TEST_F(OpBmmOutTest, DynamicShapeUpperBoundLargerThanExpected) {
282 TensorFactory<ScalarType::Float> tf;
283
284 auto x = tf.make(
285 {3, 3, 6},
286 {0.7231091856956482, 0.7423362731933594, 0.5262957811355591,
287 0.24365824460983276, 0.584592342376709, 0.033152639865875244,
288 0.13871687650680542, 0.242235004901886, 0.815468966960907,
289 0.793160617351532, 0.2782524824142456, 0.48195880651474,
290 0.8197803497314453, 0.9970665574073792, 0.6984410881996155,
291 0.5675464272499084, 0.8352431654930115, 0.2055988311767578,
292 0.593172013759613, 0.11234724521636963, 0.1534569263458252,
293 0.24170821905136108, 0.7262365221977234, 0.7010802030563354,
294 0.2038237452507019, 0.6510535478591919, 0.7744860053062439,
295 0.4368913173675537, 0.5190907716751099, 0.6158523559570312,
296 0.8101882934570312, 0.9800970554351807, 0.1146882176399231,
297 0.3167651295661926, 0.6965049505233765, 0.9142746925354004,
298 0.9351036548614502, 0.9411783814430237, 0.5995072722434998,
299 0.06520867347717285, 0.5459962487220764, 0.18719732761383057,
300 0.03402292728424072, 0.944246232509613, 0.8801798820495605,
301 0.0012360215187072754, 0.5935860276222229, 0.4157699942588806,
302 0.41771942377090454, 0.2711215615272522, 0.6922780871391296,
303 0.2038482427597046, 0.6832956671714783, 0.75285404920578});
304 auto y = tf.make(
305 {3, 6, 2},
306 {0.8579357862472534, 0.6869555711746216, 0.0051323771476745605,
307 0.17565155029296875, 0.7496575117111206, 0.6046506762504578,
308 0.1099579930305481, 0.21209025382995605, 0.9703746438026428,
309 0.8369089365005493, 0.28198742866516113, 0.3741576075553894,
310 0.023700952529907227, 0.49101293087005615, 0.12347054481506348,
311 0.11432164907455444, 0.4724501967430115, 0.5750725269317627,
312 0.2952348589897156, 0.7966887950897217, 0.19573044776916504,
313 0.9536850452423096, 0.8426499366760254, 0.07835853099822998,
314 0.3755578398704529, 0.5225613117218018, 0.572950541973114,
315 0.6185871362686157, 0.6962141394615173, 0.5299500823020935,
316 0.25603562593460083, 0.7365944981575012, 0.020375549793243408,
317 0.20364665985107422, 0.3748350739479065, 0.2564433217048645});
318 Tensor expected_result = tf.make(
319 {3, 3, 2},
320 {1.6221470832824707,
321 1.498693823814392,
322 1.224705696105957,
323 1.2123372554779053,
324 2.1629090309143066,
325 2.05692195892334,
326 0.9047035574913025,
327 1.3324503898620605,
328 1.2006582021713257,
329 1.5112680196762085,
330 1.1946606636047363,
331 1.5640640258789062,
332 1.405808448791504,
333 1.5957869291305542,
334 1.3348338603973389,
335 1.2967426776885986,
336 1.1425018310546875,
337 1.2352378368377686});
338
339 Tensor out =
340 tf.zeros({6, 6, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
341 Tensor ret = op_bmm_out(x, y, out);
342 EXPECT_TENSOR_CLOSE(out, expected_result);
343 }
344
TEST_F(OpBmmOutTest,DynamicShapeUnbound)345 TEST_F(OpBmmOutTest, DynamicShapeUnbound) {
346 GTEST_SKIP() << "Dynamic shape unbound not supported";
347 TensorFactory<ScalarType::Float> tf;
348
349 auto x = tf.make(
350 {3, 3, 6},
351 {0.7231091856956482, 0.7423362731933594, 0.5262957811355591,
352 0.24365824460983276, 0.584592342376709, 0.033152639865875244,
353 0.13871687650680542, 0.242235004901886, 0.815468966960907,
354 0.793160617351532, 0.2782524824142456, 0.48195880651474,
355 0.8197803497314453, 0.9970665574073792, 0.6984410881996155,
356 0.5675464272499084, 0.8352431654930115, 0.2055988311767578,
357 0.593172013759613, 0.11234724521636963, 0.1534569263458252,
358 0.24170821905136108, 0.7262365221977234, 0.7010802030563354,
359 0.2038237452507019, 0.6510535478591919, 0.7744860053062439,
360 0.4368913173675537, 0.5190907716751099, 0.6158523559570312,
361 0.8101882934570312, 0.9800970554351807, 0.1146882176399231,
362 0.3167651295661926, 0.6965049505233765, 0.9142746925354004,
363 0.9351036548614502, 0.9411783814430237, 0.5995072722434998,
364 0.06520867347717285, 0.5459962487220764, 0.18719732761383057,
365 0.03402292728424072, 0.944246232509613, 0.8801798820495605,
366 0.0012360215187072754, 0.5935860276222229, 0.4157699942588806,
367 0.41771942377090454, 0.2711215615272522, 0.6922780871391296,
368 0.2038482427597046, 0.6832956671714783, 0.75285404920578});
369 auto y = tf.make(
370 {3, 6, 2},
371 {0.8579357862472534, 0.6869555711746216, 0.0051323771476745605,
372 0.17565155029296875, 0.7496575117111206, 0.6046506762504578,
373 0.1099579930305481, 0.21209025382995605, 0.9703746438026428,
374 0.8369089365005493, 0.28198742866516113, 0.3741576075553894,
375 0.023700952529907227, 0.49101293087005615, 0.12347054481506348,
376 0.11432164907455444, 0.4724501967430115, 0.5750725269317627,
377 0.2952348589897156, 0.7966887950897217, 0.19573044776916504,
378 0.9536850452423096, 0.8426499366760254, 0.07835853099822998,
379 0.3755578398704529, 0.5225613117218018, 0.572950541973114,
380 0.6185871362686157, 0.6962141394615173, 0.5299500823020935,
381 0.25603562593460083, 0.7365944981575012, 0.020375549793243408,
382 0.20364665985107422, 0.3748350739479065, 0.2564433217048645});
383 Tensor expected_result = tf.make(
384 {3, 3, 2},
385 {1.6221470832824707,
386 1.498693823814392,
387 1.224705696105957,
388 1.2123372554779053,
389 2.1629090309143066,
390 2.05692195892334,
391 0.9047035574913025,
392 1.3324503898620605,
393 1.2006582021713257,
394 1.5112680196762085,
395 1.1946606636047363,
396 1.5640640258789062,
397 1.405808448791504,
398 1.5957869291305542,
399 1.3348338603973389,
400 1.2967426776885986,
401 1.1425018310546875,
402 1.2352378368377686});
403
404 Tensor out = tf.zeros(
405 {1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
406 Tensor ret = op_bmm_out(x, y, out);
407 EXPECT_TENSOR_CLOSE(out, expected_result);
408 }
409