xref: /aosp_15_r20/external/executorch/kernels/test/op_bmm_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
16 
17 #include <gtest/gtest.h>
18 
19 using namespace ::testing;
20 using exec_aten::ArrayRef;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 using torch::executor::testing::TensorFactory;
24 
25 class OpBmmOutTest : public OperatorTest {
26  protected:
op_bmm_out(const Tensor & self,const Tensor & mat2,Tensor & out)27   Tensor& op_bmm_out(const Tensor& self, const Tensor& mat2, Tensor& out) {
28     return torch::executor::aten::bmm_outf(context_, self, mat2, out);
29   }
30 
31   template <class CTYPE, exec_aten::ScalarType DTYPE>
test_dtype()32   void test_dtype() {
33     TensorFactory<DTYPE> tf;
34 
35     // Gives 4 * 2 * 3 = 24, shape (10, 3, 5)
36     Tensor x = tf.full({10, 3, 4}, 2);
37     Tensor y = tf.full({10, 4, 5}, 3);
38 
39     Tensor out = tf.zeros({10, 3, 5});
40     op_bmm_out(x, y, out);
41 
42     Tensor expected = tf.full({10, 3, 5}, 24);
43 
44     EXPECT_TENSOR_EQ(out, expected);
45   }
46 };
47 
TEST_F(OpBmmOutTest, OutputDim) {
  TensorFactory<ScalarType::Int> tf_int;

  // All-ones batches of shape (10, 3, 4) and (10, 4, 5). The product has shape
  // (10, 3, 5), and each element sums 4 ones along the inner dimension.
  Tensor lhs = tf_int.ones({10, 3, 4});
  Tensor rhs = tf_int.ones({10, 4, 5});
  Tensor result = tf_int.zeros({10, 3, 5});

  Tensor returned = op_bmm_out(lhs, rhs, result);

  // The out-variant must hand back the tensor it was given.
  EXPECT_TENSOR_EQ(returned, result);

  EXPECT_TENSOR_EQ(result, tf_int.full({10, 3, 5}, 4));
}
68 
TEST_F(OpBmmOutTest, OutputDimFloat) {
  TensorFactory<ScalarType::Float> tf;

  // Batch of two (4, 5) matrices.
  // clang-format off
  Tensor x = tf.make(
      {2, 4, 5},
      {
        4., 3., 1., 1., 1.,
        3., 1., 4., 4., 2.,
        1., 1., 1., 3., 3.,
        4., 2., 2., 2., 3.,

        1., 3., 1., 4., 4.,
        1., 1., 2., 4., 3.,
        4., 3., 4., 1., 2.,
        1., 4., 4., 4., 4.,
      });
  // clang-format on

  // Batch of two (5, 3) matrices.
  // clang-format off
  Tensor y = tf.make(
      {2, 5, 3},
      {
        4., 4., 4.,
        2., 3., 1.,
        1., 4., 4.,
        3., 1., 2.,
        1., 4., 3.,

        1., 4., 4.,
        4., 4., 4.,
        2., 1., 4.,
        1., 4., 3.,
        1., 4., 4.,
      });
  // clang-format on

  // (2, 4, 5) x (2, 5, 3) -> output shape should be (2, 4, 3).
  Tensor out = tf.zeros({2, 4, 3});

  Tensor ret = op_bmm_out(x, y, out);

  // Should always return the provided out Tensor.
  EXPECT_TENSOR_EQ(ret, out);

  // Per-batch matrix products x[b] @ y[b]; e.g.
  // out[0][0][0] = 4*4 + 3*2 + 1*1 + 1*3 + 1*1 = 27.
  // clang-format off
  Tensor expected = tf.make(
      {2, 4, 3},
      {
        27., 34., 28.,
        32., 43., 43.,
        19., 26., 24.,
        31., 44., 39.,

        23., 49., 48.,
        16., 38., 40.,
        27., 44., 55.,
        33., 56., 64.,
      });
  // clang-format on

  EXPECT_TENSOR_EQ(out, expected);
}
132 
133 /// A generic smoke test that works for any dtype that supports ones() and
134 /// zeros().
TEST_F(OpBmmOutTest,AllDtypesSupported)135 TEST_F(OpBmmOutTest, AllDtypesSupported) {
136 #define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
137   ET_FORALL_REAL_TYPES(TEST_ENTRY);
138 #undef TEST_ENTRY
139   // TODO: Also add tests for half, complex, quantized, and other types. Easiest
140   // way to do that would be to make TensorFactory support zeros() and ones()
141   // for those types.
142 }
143 
TEST_F(OpBmmOutTest, EmptyInputWithEmptyOutTensorPasses) {
  TensorFactory<ScalarType::Int> tf_int;

  // (2, 2, 2) x (2, 2, 0): the result has a zero-sized last dimension, so it
  // holds no elements.
  Tensor lhs = tf_int.full({2, 2, 2}, 3);
  Tensor rhs = tf_int.make({2, 2, 0}, {});
  Tensor result = tf_int.make({2, 2, 0}, {});

  // Confirm the out tensor starts empty...
  EXPECT_EQ(result.numel(), 0);

  op_bmm_out(lhs, rhs, result);

  // ...and is still empty after the call.
  EXPECT_EQ(result.numel(), 0);
}
159 
TEST_F(OpBmmOutTest, MismatchedDimensionsDies) {
  TensorFactory<ScalarType::Int> tf_int;

  Tensor lhs = tf_int.ones({2, 10, 3});
  Tensor result = tf_int.ones({2, 10, 4});

  // Neither the batch size (3 vs 2) nor the contraction size (7 vs 3) of this
  // rhs matches lhs, so the kernel must reject it.
  Tensor bad_rhs = tf_int.ones({3, 7, 4});
  Tensor good_rhs = tf_int.ones({2, 3, 4});

  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(lhs, bad_rhs, result));

  // With a compatible rhs, every element sums 3 ones.
  EXPECT_TENSOR_EQ(
      op_bmm_out(lhs, good_rhs, result), tf_int.full({2, 10, 4}, 3));
}
175 
TEST_F(OpBmmOutTest, MismatchedDimensionSizeDies) {
  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
    GTEST_SKIP() << "ATen kernel can handle mismatched dimension size";
  }
  TensorFactory<ScalarType::Int> tf;

  Tensor x = tf.ones({2, 10, 3});

  // wrong_y is 2-D while x is 3-D; the kernel is expected to reject the
  // mismatched number of dimensions.
  Tensor wrong_y = tf.ones({7, 4});
  Tensor right_y = tf.ones({2, 3, 4});

  // wrong_out is 2-D, but x @ right_y produces a 3-D (2, 10, 4) result.
  Tensor right_out = tf.ones({2, 10, 4});
  Tensor wrong_out = tf.ones({7, 5});

  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(x, right_y, wrong_out));
  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(x, wrong_y, right_out));
}
197 
TEST_F(OpBmmOutTest, WrongOutShapeDies) {
  const bool running_aten =
      torch::executor::testing::SupportedFeatures::get()->is_aten;
  if (running_aten) {
    GTEST_SKIP() << "ATen kernel can handle wrong out shape";
  }
  TensorFactory<ScalarType::Int> tf_int;

  Tensor lhs = tf_int.ones({2, 10, 3});
  Tensor rhs = tf_int.ones({2, 3, 4});

  // lhs @ rhs has shape (2, 10, 4); a (3, 7, 5) out tensor does not match it.
  Tensor good_out = tf_int.ones({2, 10, 4});
  Tensor bad_out = tf_int.ones({3, 7, 5});

  ET_EXPECT_KERNEL_FAILURE(context_, op_bmm_out(lhs, rhs, bad_out));

  // A correctly shaped out tensor receives sums of 3 ones.
  EXPECT_TENSOR_EQ(
      op_bmm_out(lhs, rhs, good_out), tf_int.full({2, 10, 4}, 3));
}
216 
TEST_F(OpBmmOutTest,DynamicShapeUpperBoundSameAsExpected)217 TEST_F(OpBmmOutTest, DynamicShapeUpperBoundSameAsExpected) {
218   TensorFactory<ScalarType::Float> tf;
219 
220   auto x = tf.make(
221       {3, 3, 6},
222       {0.7231091856956482,    0.7423362731933594,  0.5262957811355591,
223        0.24365824460983276,   0.584592342376709,   0.033152639865875244,
224        0.13871687650680542,   0.242235004901886,   0.815468966960907,
225        0.793160617351532,     0.2782524824142456,  0.48195880651474,
226        0.8197803497314453,    0.9970665574073792,  0.6984410881996155,
227        0.5675464272499084,    0.8352431654930115,  0.2055988311767578,
228        0.593172013759613,     0.11234724521636963, 0.1534569263458252,
229        0.24170821905136108,   0.7262365221977234,  0.7010802030563354,
230        0.2038237452507019,    0.6510535478591919,  0.7744860053062439,
231        0.4368913173675537,    0.5190907716751099,  0.6158523559570312,
232        0.8101882934570312,    0.9800970554351807,  0.1146882176399231,
233        0.3167651295661926,    0.6965049505233765,  0.9142746925354004,
234        0.9351036548614502,    0.9411783814430237,  0.5995072722434998,
235        0.06520867347717285,   0.5459962487220764,  0.18719732761383057,
236        0.03402292728424072,   0.944246232509613,   0.8801798820495605,
237        0.0012360215187072754, 0.5935860276222229,  0.4157699942588806,
238        0.41771942377090454,   0.2711215615272522,  0.6922780871391296,
239        0.2038482427597046,    0.6832956671714783,  0.75285404920578});
240   auto y = tf.make(
241       {3, 6, 2},
242       {0.8579357862472534,   0.6869555711746216,  0.0051323771476745605,
243        0.17565155029296875,  0.7496575117111206,  0.6046506762504578,
244        0.1099579930305481,   0.21209025382995605, 0.9703746438026428,
245        0.8369089365005493,   0.28198742866516113, 0.3741576075553894,
246        0.023700952529907227, 0.49101293087005615, 0.12347054481506348,
247        0.11432164907455444,  0.4724501967430115,  0.5750725269317627,
248        0.2952348589897156,   0.7966887950897217,  0.19573044776916504,
249        0.9536850452423096,   0.8426499366760254,  0.07835853099822998,
250        0.3755578398704529,   0.5225613117218018,  0.572950541973114,
251        0.6185871362686157,   0.6962141394615173,  0.5299500823020935,
252        0.25603562593460083,  0.7365944981575012,  0.020375549793243408,
253        0.20364665985107422,  0.3748350739479065,  0.2564433217048645});
254   Tensor expected_result = tf.make(
255       {3, 3, 2},
256       {1.6221470832824707,
257        1.498693823814392,
258        1.224705696105957,
259        1.2123372554779053,
260        2.1629090309143066,
261        2.05692195892334,
262        0.9047035574913025,
263        1.3324503898620605,
264        1.2006582021713257,
265        1.5112680196762085,
266        1.1946606636047363,
267        1.5640640258789062,
268        1.405808448791504,
269        1.5957869291305542,
270        1.3348338603973389,
271        1.2967426776885986,
272        1.1425018310546875,
273        1.2352378368377686});
274 
275   Tensor out =
276       tf.zeros({3, 3, 2}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
277   Tensor ret = op_bmm_out(x, y, out);
278   EXPECT_TENSOR_CLOSE(out, expected_result);
279 }
280 
TEST_F(OpBmmOutTest,DynamicShapeUpperBoundLargerThanExpected)281 TEST_F(OpBmmOutTest, DynamicShapeUpperBoundLargerThanExpected) {
282   TensorFactory<ScalarType::Float> tf;
283 
284   auto x = tf.make(
285       {3, 3, 6},
286       {0.7231091856956482,    0.7423362731933594,  0.5262957811355591,
287        0.24365824460983276,   0.584592342376709,   0.033152639865875244,
288        0.13871687650680542,   0.242235004901886,   0.815468966960907,
289        0.793160617351532,     0.2782524824142456,  0.48195880651474,
290        0.8197803497314453,    0.9970665574073792,  0.6984410881996155,
291        0.5675464272499084,    0.8352431654930115,  0.2055988311767578,
292        0.593172013759613,     0.11234724521636963, 0.1534569263458252,
293        0.24170821905136108,   0.7262365221977234,  0.7010802030563354,
294        0.2038237452507019,    0.6510535478591919,  0.7744860053062439,
295        0.4368913173675537,    0.5190907716751099,  0.6158523559570312,
296        0.8101882934570312,    0.9800970554351807,  0.1146882176399231,
297        0.3167651295661926,    0.6965049505233765,  0.9142746925354004,
298        0.9351036548614502,    0.9411783814430237,  0.5995072722434998,
299        0.06520867347717285,   0.5459962487220764,  0.18719732761383057,
300        0.03402292728424072,   0.944246232509613,   0.8801798820495605,
301        0.0012360215187072754, 0.5935860276222229,  0.4157699942588806,
302        0.41771942377090454,   0.2711215615272522,  0.6922780871391296,
303        0.2038482427597046,    0.6832956671714783,  0.75285404920578});
304   auto y = tf.make(
305       {3, 6, 2},
306       {0.8579357862472534,   0.6869555711746216,  0.0051323771476745605,
307        0.17565155029296875,  0.7496575117111206,  0.6046506762504578,
308        0.1099579930305481,   0.21209025382995605, 0.9703746438026428,
309        0.8369089365005493,   0.28198742866516113, 0.3741576075553894,
310        0.023700952529907227, 0.49101293087005615, 0.12347054481506348,
311        0.11432164907455444,  0.4724501967430115,  0.5750725269317627,
312        0.2952348589897156,   0.7966887950897217,  0.19573044776916504,
313        0.9536850452423096,   0.8426499366760254,  0.07835853099822998,
314        0.3755578398704529,   0.5225613117218018,  0.572950541973114,
315        0.6185871362686157,   0.6962141394615173,  0.5299500823020935,
316        0.25603562593460083,  0.7365944981575012,  0.020375549793243408,
317        0.20364665985107422,  0.3748350739479065,  0.2564433217048645});
318   Tensor expected_result = tf.make(
319       {3, 3, 2},
320       {1.6221470832824707,
321        1.498693823814392,
322        1.224705696105957,
323        1.2123372554779053,
324        2.1629090309143066,
325        2.05692195892334,
326        0.9047035574913025,
327        1.3324503898620605,
328        1.2006582021713257,
329        1.5112680196762085,
330        1.1946606636047363,
331        1.5640640258789062,
332        1.405808448791504,
333        1.5957869291305542,
334        1.3348338603973389,
335        1.2967426776885986,
336        1.1425018310546875,
337        1.2352378368377686});
338 
339   Tensor out =
340       tf.zeros({6, 6, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND);
341   Tensor ret = op_bmm_out(x, y, out);
342   EXPECT_TENSOR_CLOSE(out, expected_result);
343 }
344 
TEST_F(OpBmmOutTest,DynamicShapeUnbound)345 TEST_F(OpBmmOutTest, DynamicShapeUnbound) {
346   GTEST_SKIP() << "Dynamic shape unbound not supported";
347   TensorFactory<ScalarType::Float> tf;
348 
349   auto x = tf.make(
350       {3, 3, 6},
351       {0.7231091856956482,    0.7423362731933594,  0.5262957811355591,
352        0.24365824460983276,   0.584592342376709,   0.033152639865875244,
353        0.13871687650680542,   0.242235004901886,   0.815468966960907,
354        0.793160617351532,     0.2782524824142456,  0.48195880651474,
355        0.8197803497314453,    0.9970665574073792,  0.6984410881996155,
356        0.5675464272499084,    0.8352431654930115,  0.2055988311767578,
357        0.593172013759613,     0.11234724521636963, 0.1534569263458252,
358        0.24170821905136108,   0.7262365221977234,  0.7010802030563354,
359        0.2038237452507019,    0.6510535478591919,  0.7744860053062439,
360        0.4368913173675537,    0.5190907716751099,  0.6158523559570312,
361        0.8101882934570312,    0.9800970554351807,  0.1146882176399231,
362        0.3167651295661926,    0.6965049505233765,  0.9142746925354004,
363        0.9351036548614502,    0.9411783814430237,  0.5995072722434998,
364        0.06520867347717285,   0.5459962487220764,  0.18719732761383057,
365        0.03402292728424072,   0.944246232509613,   0.8801798820495605,
366        0.0012360215187072754, 0.5935860276222229,  0.4157699942588806,
367        0.41771942377090454,   0.2711215615272522,  0.6922780871391296,
368        0.2038482427597046,    0.6832956671714783,  0.75285404920578});
369   auto y = tf.make(
370       {3, 6, 2},
371       {0.8579357862472534,   0.6869555711746216,  0.0051323771476745605,
372        0.17565155029296875,  0.7496575117111206,  0.6046506762504578,
373        0.1099579930305481,   0.21209025382995605, 0.9703746438026428,
374        0.8369089365005493,   0.28198742866516113, 0.3741576075553894,
375        0.023700952529907227, 0.49101293087005615, 0.12347054481506348,
376        0.11432164907455444,  0.4724501967430115,  0.5750725269317627,
377        0.2952348589897156,   0.7966887950897217,  0.19573044776916504,
378        0.9536850452423096,   0.8426499366760254,  0.07835853099822998,
379        0.3755578398704529,   0.5225613117218018,  0.572950541973114,
380        0.6185871362686157,   0.6962141394615173,  0.5299500823020935,
381        0.25603562593460083,  0.7365944981575012,  0.020375549793243408,
382        0.20364665985107422,  0.3748350739479065,  0.2564433217048645});
383   Tensor expected_result = tf.make(
384       {3, 3, 2},
385       {1.6221470832824707,
386        1.498693823814392,
387        1.224705696105957,
388        1.2123372554779053,
389        2.1629090309143066,
390        2.05692195892334,
391        0.9047035574913025,
392        1.3324503898620605,
393        1.2006582021713257,
394        1.5112680196762085,
395        1.1946606636047363,
396        1.5640640258789062,
397        1.405808448791504,
398        1.5957869291305542,
399        1.3348338603973389,
400        1.2967426776885986,
401        1.1425018310546875,
402        1.2352378368377686});
403 
404   Tensor out = tf.zeros(
405       {1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
406   Tensor ret = op_bmm_out(x, y, out);
407   EXPECT_TENSOR_CLOSE(out, expected_result);
408 }
409