xref: /aosp_15_r20/external/armnn/delegate/test/BatchMatMulTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "BatchMatMulTestHelper.hpp"
7 
8 #include <armnn_delegate.hpp>
9 
10 #include <flatbuffers/flatbuffers.h>
11 #include <schema_generated.h>
12 
13 #include <doctest/doctest.h>
14 
15 namespace armnnDelegate
16 {
17 
BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId> & backends)18     void BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
19     {
20         // Set input data
21         std::vector<int32_t> LHSInputShape { 2, 2 };
22         std::vector<int32_t> RHSInputShape { 2, 2 };
23         std::vector<int32_t> outputShape   { 2, 2 };
24 
25         std::vector<float> LHSInputValues = { 1, 2,
26                                               3, 4 };
27 
28         std::vector<float> RHSInputValues = { 5, 6,
29                                               7, 8  };
30 
31         std::vector<float> expectedOutputValues = { 19, 22,
32                                                     43, 50 };
33 
34         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
35                                ::tflite::TensorType_FLOAT32,
36                                backends,
37                                LHSInputShape,
38                                RHSInputShape,
39                                outputShape,
40                                LHSInputValues,
41                                RHSInputValues,
42                                expectedOutputValues,
43                                false,
44                                false);
45     }
BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId> & backends)46     void BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
47     {
48         // Set input data
49         std::vector<int32_t> LHSInputShape { 2, 2 };
50         std::vector<int32_t> RHSInputShape { 2, 2 };
51         std::vector<int32_t> outputShape   { 2, 2 };
52 
53         std::vector<int8_t> LHSInputValues = { 1, 2,
54                                               3, 4 };
55 
56         std::vector<int8_t> RHSInputValues = { 5, 6,
57                                               7, 8  };
58 
59         std::vector<int8_t> expectedOutputValues = { 19, 22,
60                                                     43, 50 };
61 
62         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
63                                ::tflite::TensorType_INT8,
64                                backends,
65                                LHSInputShape,
66                                RHSInputShape,
67                                outputShape,
68                                LHSInputValues,
69                                RHSInputValues,
70                                expectedOutputValues,
71                                false,
72                                false);
73     }
74 
BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId> & backends)75     void BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
76     {
77         // Set input data
78         std::vector<int32_t> LHSInputShape { 1,2,2 };
79         std::vector<int32_t> RHSInputShape { 1,2,2 };
80         std::vector<int32_t> outputShape   { 1,2,2 };
81 
82         std::vector<float> LHSInputValues = { 1, 2,
83                                               3, 4 };
84 
85         std::vector<float> RHSInputValues = { 5, 6,
86                                               7, 8  };
87 
88         std::vector<float> expectedOutputValues = { 19, 22,
89                                                     43, 50 };
90 
91         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
92                                ::tflite::TensorType_FLOAT32,
93                                backends,
94                                LHSInputShape,
95                                RHSInputShape,
96                                outputShape,
97                                LHSInputValues,
98                                RHSInputValues,
99                                expectedOutputValues,
100                                false,
101                                false);
102     }
103 
BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId> & backends)104     void BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
105     {
106         // Set input data
107         std::vector<int32_t> LHSInputShape { 1,2,2 };
108         std::vector<int32_t> RHSInputShape { 1,2,2 };
109         std::vector<int32_t> outputShape   { 1,2,2 };
110 
111         std::vector<int8_t> LHSInputValues = { 1, 2,
112                                               3, 4 };
113 
114         std::vector<int8_t> RHSInputValues = { 5, 6,
115                                               7, 8  };
116 
117         std::vector<int8_t> expectedOutputValues = { 19, 22,
118                                                     43, 50 };
119 
120         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
121                                ::tflite::TensorType_INT8,
122                                backends,
123                                LHSInputShape,
124                                RHSInputShape,
125                                outputShape,
126                                LHSInputValues,
127                                RHSInputValues,
128                                expectedOutputValues,
129                                false,
130                                false);
131     }
132 
BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId> & backends)133     void BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
134     {
135         // Set input data
136         std::vector<int32_t> LHSInputShape { 1,1,2,2 };
137         std::vector<int32_t> RHSInputShape { 1,1,2,2 };
138         std::vector<int32_t> outputShape   { 1,1,2,2 };
139 
140         std::vector<float> LHSInputValues = { 1, 2,
141                                               3, 4 };
142 
143         std::vector<float> RHSInputValues = { 5, 6,
144                                               7, 8  };
145 
146         std::vector<float> expectedOutputValues = { 19, 22,
147                                                     43, 50 };
148 
149         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
150                                ::tflite::TensorType_FLOAT32,
151                                backends,
152                                LHSInputShape,
153                                RHSInputShape,
154                                outputShape,
155                                LHSInputValues,
156                                RHSInputValues,
157                                expectedOutputValues,
158                                false,
159                                false);
160     }
161 
BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId> & backends)162     void BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
163     {
164         // Set input data
165         std::vector<int32_t> LHSInputShape { 1,1,2,2};
166         std::vector<int32_t> RHSInputShape { 1,1,2,2 };
167         std::vector<int32_t> outputShape   { 1,1,2,2 };
168 
169         std::vector<int8_t> LHSInputValues = { 1, 2,
170                                               3, 4 };
171 
172         std::vector<int8_t> RHSInputValues = { 5, 6,
173                                               7, 8 };
174 
175         std::vector<int8_t> expectedOutputValues = { 19, 22,
176                                                     43, 50 };
177 
178         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
179                                ::tflite::TensorType_INT8,
180                                backends,
181                                LHSInputShape,
182                                RHSInputShape,
183                                outputShape,
184                                LHSInputValues,
185                                RHSInputValues,
186                                expectedOutputValues,
187                                false,
188                                false);
189     }
190 
BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId> & backends)191     void BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId>& backends)
192     {
193         // Set input data
194         std::vector<int32_t> LHSInputShape { 2,2,2 };
195         std::vector<int32_t> RHSInputShape { 2,2,2 };
196         std::vector<int32_t> outputShape   { 2,2,2 };
197 
198         std::vector<float> LHSInputValues = { 1, 2,
199                                               3, 4,
200 
201                                               9, 10,
202                                               11, 12 };
203 
204         std::vector<float> RHSInputValues = { 5, 6,
205                                               7, 8,
206 
207                                               13, 14,
208                                               15, 16 };
209 
210         std::vector<float> expectedOutputValues = { 19, 22,
211                                                     43, 50,
212 
213                                                     267, 286,
214                                                     323, 346 };
215 
216         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
217                                ::tflite::TensorType_FLOAT32,
218                                backends,
219                                LHSInputShape,
220                                RHSInputShape,
221                                outputShape,
222                                LHSInputValues,
223                                RHSInputValues,
224                                expectedOutputValues,
225                                false,
226                                false);
227     }
228 
BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId> & backends)229     void BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId>& backends)
230     {
231         // Set input data
232         std::vector<int32_t> LHSInputShape { 2,2,2 };
233         std::vector<int32_t> RHSInputShape { 2,2,2 };
234         std::vector<int32_t> outputShape   { 2,2,2 };
235 
236         std::vector<int8_t> LHSInputValues = { 1, 2,
237                                               3, 4,
238 
239                                               9, 10,
240                                               11, 12 };
241 
242         std::vector<int8_t> RHSInputValues = { 5, 6,
243                                               7, 8,
244 
245                                               1, 2,
246                                               3, 4 };
247 
248         std::vector<int8_t> expectedOutputValues = { 19, 22,
249                                                     43, 50,
250 
251                                                     39, 58,
252                                                     47, 70 };
253 
254         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
255                                ::tflite::TensorType_INT8,
256                                backends,
257                                LHSInputShape,
258                                RHSInputShape,
259                                outputShape,
260                                LHSInputValues,
261                                RHSInputValues,
262                                expectedOutputValues,
263                                false,
264                                false);
265     }
266 
BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId> & backends)267     void BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId>& backends)
268     {
269         // Set input data
270         std::vector<int32_t> LHSInputShape { 2,2,2 };
271         std::vector<int32_t> RHSInputShape { 2,2 };
272         std::vector<int32_t> outputShape   { 2,2,2 };
273 
274         std::vector<float> LHSInputValues = { 1, 2,
275                                               3, 4,
276 
277                                               9, 10,
278                                               11, 12 };
279 
280         std::vector<float> RHSInputValues = { 13, 14,
281                                               15, 16 };
282 
283         std::vector<float> expectedOutputValues = {  43, 46,
284                                                      99, 106,
285 
286                                                      267, 286,
287                                                      323, 346 };
288 
289         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
290                                ::tflite::TensorType_FLOAT32,
291                                backends,
292                                LHSInputShape,
293                                RHSInputShape,
294                                outputShape,
295                                LHSInputValues,
296                                RHSInputValues,
297                                expectedOutputValues,
298                                false,
299                                false);
300     }
301 
BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId> & backends)302     void BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId>& backends)
303     {
304         // Set input data
305         std::vector<int32_t> LHSInputShape { 2,2,2 };
306         std::vector<int32_t> RHSInputShape { 1,2,2 };
307         std::vector<int32_t> outputShape   { 2,2,2 };
308 
309         std::vector<int8_t> LHSInputValues = { 1, 2,
310                                               3, 4,
311 
312                                               9, 10,
313                                               11, 12 };
314 
315         std::vector<int8_t> RHSInputValues = { 1, 2,
316                                                3, 4 };
317 
318         std::vector<int8_t> expectedOutputValues = {  7,  10,
319                                                       15, 22,
320 
321                                                       39, 58,
322                                                       47, 70 };
323 
324         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
325                                ::tflite::TensorType_INT8,
326                                backends,
327                                LHSInputShape,
328                                RHSInputShape,
329                                outputShape,
330                                LHSInputValues,
331                                RHSInputValues,
332                                expectedOutputValues,
333                                false,
334                                false);
335     }
336 
BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId> & backends)337     void BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId>& backends)
338     {
339         // Set input data
340         std::vector<int32_t> LHSInputShape { 2,2,2 };
341         std::vector<int32_t> RHSInputShape { 2,2 };
342         std::vector<int32_t> outputShape   { 2,2,2 };
343 
344         std::vector<float> LHSInputValues = { 1, 2,
345                                               3, 4,
346 
347                                               9, 10,
348                                               11, 12 };
349 
350         std::vector<float> RHSInputValues = { 13, 14,
351                                               15, 16 };
352 
353         std::vector<float> expectedOutputValues = {  43, 46,
354                                                      99, 106,
355 
356                                                      267, 286,
357                                                      323, 346 };
358 
359         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
360                                ::tflite::TensorType_FLOAT32,
361                                backends,
362                                LHSInputShape,
363                                RHSInputShape,
364                                outputShape,
365                                LHSInputValues,
366                                RHSInputValues,
367                                expectedOutputValues,
368                                false,
369                                false);
370     }
371 
BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId> & backends)372     void BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId>& backends)
373     {
374         // Set input data
375         std::vector<int32_t> LHSInputShape { 2,2,2 };
376         std::vector<int32_t> RHSInputShape { 2,2 };
377         std::vector<int32_t> outputShape   { 2,2,2 };
378 
379         std::vector<int8_t> LHSInputValues = { 1, 2,
380                                               3, 4,
381 
382                                               9, 10,
383                                               11, 12 };
384 
385         std::vector<int8_t> RHSInputValues = { 1, 2,
386                                                3, 4 };
387 
388         std::vector<int8_t> expectedOutputValues = {  7, 10,
389                                                       15, 22,
390 
391                                                       39, 58,
392                                                       47, 70 };
393 
394         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
395                                ::tflite::TensorType_INT8,
396                                backends,
397                                LHSInputShape,
398                                RHSInputShape,
399                                outputShape,
400                                LHSInputValues,
401                                RHSInputValues,
402                                expectedOutputValues,
403                                false,
404                                false);
405     }
406 
BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId> & backends)407     void BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId>& backends)
408     {
409         // Set input data
410         std::vector<int32_t> LHSInputShape { 1,1 };
411         std::vector<int32_t> RHSInputShape { 1,1 };
412         std::vector<int32_t> outputShape   { 1,1 };
413 
414         std::vector<float> LHSInputValues = { 3 };
415 
416         std::vector<float> RHSInputValues = { 5 };
417 
418         std::vector<float> expectedOutputValues = { 15 };
419 
420         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
421                                ::tflite::TensorType_FLOAT32,
422                                backends,
423                                LHSInputShape,
424                                RHSInputShape,
425                                outputShape,
426                                LHSInputValues,
427                                RHSInputValues,
428                                expectedOutputValues,
429                                false,
430                                false);
431     }
BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId> & backends)432     void BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId>& backends)
433     {
434         // Set input data
435         std::vector<int32_t> LHSInputShape { 1,1 };
436         std::vector<int32_t> RHSInputShape { 1,1 };
437         std::vector<int32_t> outputShape   { 1,1 };
438 
439         std::vector<int8_t> LHSInputValues = { 3 };
440 
441         std::vector<int8_t> RHSInputValues = { 5 };
442 
443         std::vector<int8_t> expectedOutputValues = { 15 };
444 
445         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
446                                 ::tflite::TensorType_INT8,
447                                 backends,
448                                 LHSInputShape,
449                                 RHSInputShape,
450                                 outputShape,
451                                 LHSInputValues,
452                                 RHSInputValues,
453                                 expectedOutputValues,
454                                 false,
455                                 false);
456     }
457 
BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId> & backends)458     void BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId>& backends)
459     {
460         // Set input data
461         std::vector<int32_t> LHSInputShape { 2,5,3 };
462         std::vector<int32_t> RHSInputShape { 2,3,4 };
463         std::vector<int32_t> outputShape   { 2,5,4 };
464 
465         std::vector<float> LHSInputValues = { 8, 8, 4,
466                                               6, 1, 3,
467                                               8, 8, 3,
468                                               8, 9, 8,
469                                               5, 4, 4,
470 
471                                               1, 8, 5,
472                                               7, 1, 1,
473                                               8, 7, 9,
474                                               3, 2, 7,
475                                               8, 5, 3 };
476 
477         std::vector<float> RHSInputValues = { 6, 2, 3, 2,
478                                               6, 2, 2, 8,
479                                               3, 7, 8, 1,
480 
481                                               7, 2, 9, 5,
482                                               2, 3, 1, 3,
483                                               2, 7, 7, 5 };
484 
485         std::vector<float> expectedOutputValues = { 108, 60, 72, 84,
486                                                     51, 35, 44, 23,
487                                                     105, 53, 64, 83,
488                                                     126, 90, 106, 96,
489                                                     66, 46, 55, 46,
490 
491                                                     33, 61, 52, 54,
492                                                     53, 24, 71, 43,
493                                                     88, 100, 142, 106,
494                                                     39, 61, 78, 56,
495                                                     72, 52, 98, 70 };
496 
497         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
498                                ::tflite::TensorType_FLOAT32,
499                                backends,
500                                LHSInputShape,
501                                RHSInputShape,
502                                outputShape,
503                                LHSInputValues,
504                                RHSInputValues,
505                                expectedOutputValues,
506                                false,
507                                false);
508     }
509 
BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId> & backends)510     void BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId>& backends)
511     {
512         // Set input data
513         std::vector<int32_t> LHSInputShape { 2,5,3 };
514         std::vector<int32_t> RHSInputShape { 2,3,4 };
515         std::vector<int32_t> outputShape   { 2,5,4 };
516 
517         std::vector<int8_t> LHSInputValues = { 8, 8, 4,
518                                               6, 1, 3,
519                                               8, 8, 3,
520                                               8, 9, 8,
521                                               5, 4, 4,
522 
523                                               1, 8, 5,
524                                               7, 1, 1,
525                                               8, 7, 9,
526                                               3, 2, 7,
527                                               8, 5, 3 };
528 
529         std::vector<int8_t> RHSInputValues = { 6, 2, 3, 2,
530                                               6, 2, 2, 8,
531                                               3, 7, 8, 1,
532 
533                                               7, 2, 3, 5,
534                                               2, 3, 1, 3,
535                                               2, 7, 7, 5 };
536 
537         std::vector<int8_t> expectedOutputValues = { 108, 60, 72, 84,
538                                                     51, 35, 44, 23,
539                                                     105, 53, 64, 83,
540                                                     126, 90, 106, 96,
541                                                     66, 46, 55, 46,
542 
543                                                     33, 61, 46, 54,
544                                                     53, 24, 29, 43,
545                                                     88, 100, 94, 106,
546                                                     39, 61, 60, 56,
547                                                     72, 52, 50, 70 };
548 
549         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
550                                ::tflite::TensorType_INT8,
551                                backends,
552                                LHSInputShape,
553                                RHSInputShape,
554                                outputShape,
555                                LHSInputValues,
556                                RHSInputValues,
557                                expectedOutputValues,
558                                false,
559                                false);
560     }
561 
BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId> & backends)562     void BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId>& backends)
563     {
564         // Set input data
565         std::vector<int32_t> LHSInputShape { 3,3 };
566         std::vector<int32_t> RHSInputShape { 3,3 };
567         std::vector<int32_t> outputShape   { 3,3 };
568 
569         std::vector<float> LHSInputValues = { 3, 1, 1,
570                                               1, 3, -1,
571                                               2, 4, 1 };
572 
573         std::vector<float> RHSInputValues = { 1, 0, 0,
574                                               0, 1, 0,
575                                               0, 0, 1 };
576 
577         std::vector<float> expectedOutputValues = { 3, 1, 2,
578                                                     1, 3, 4,
579                                                     1, -1, 1 };
580 
581         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
582                                ::tflite::TensorType_FLOAT32,
583                                backends,
584                                LHSInputShape,
585                                RHSInputShape,
586                                outputShape,
587                                LHSInputValues,
588                                RHSInputValues,
589                                expectedOutputValues,
590                                true,
591                                false);
592     }
593 
BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId> & backends)594     void BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId>& backends)
595     {
596         // Set input data
597         std::vector<int32_t> LHSInputShape { 3,3 };
598         std::vector<int32_t> RHSInputShape { 3,3 };
599         std::vector<int32_t> outputShape   { 3,3 };
600 
601         std::vector<int8_t> LHSInputValues = { 3, 1, 1,
602                                               1, 3, -1,
603                                               2, 4, 1 };
604 
605         std::vector<int8_t> RHSInputValues = { 1, 0, 0,
606                                               0, 1, 0,
607                                               0, 0, 1 };
608 
609         std::vector<int8_t> expectedOutputValues = { 3, 1, 2,
610                                                      1, 3, 4,
611                                                      1, -1, 1 };
612 
613         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
614                                ::tflite::TensorType_INT8,
615                                backends,
616                                LHSInputShape,
617                                RHSInputShape,
618                                outputShape,
619                                LHSInputValues,
620                                RHSInputValues,
621                                expectedOutputValues,
622                                true,
623                                false);
624     }
625 
626     TEST_SUITE("BATCH_MATMUL_CpuRefTests")
627     {
628         TEST_CASE("BATCH_MATMUL_Fp32_CpuRefTests")
629         {
630             std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
631             BatchMatMul2DFp32SimpleTest       (backends);
632             BatchMatMul3DFp32SimpleTest       (backends);
633             BatchMatMul4DFp32SimpleTest       (backends);
634             BatchMatMul3DFp32BatchTest        (backends);
635             BatchMatMul3DFp32BroadcastTest    (backends);
636             BatchMatMul3D2DFp32BroadcastTest  (backends);
637             BatchMatMul2DFp32TinyTest         (backends);
638             BatchMatMulNonSquareFp32Test      (backends);
639             BatchMatMul2DFp32SimpleAdjointTest(backends);
640         }
641 
642         TEST_CASE("BATCH_MATMUL_Int8_CpuRefTests")
643         {
644             std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
645             BatchMatMul2DInt8SimpleTest       (backends);
646             BatchMatMul3DInt8SimpleTest       (backends);
647             BatchMatMul4DInt8SimpleTest       (backends);
648             BatchMatMul3DInt8BatchTest        (backends);
649             BatchMatMul3DInt8BroadcastTest    (backends);
650             BatchMatMul3D2DInt8BroadcastTest  (backends);
651             BatchMatMul2DInt8TinyTest         (backends);
652             BatchMatMulNonSquareInt8Test      (backends);
653             BatchMatMul2DInt8SimpleAdjointTest(backends);
654         }
655     }
656 
657     TEST_SUITE("BATCH_MATMUL_CpuAccTests")
658     {
659         TEST_CASE("BATCH_MATMUL_Fp32_CpuAccTests")
660         {
661             std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
662             BatchMatMul2DFp32SimpleTest       (backends);
663             BatchMatMul3DFp32SimpleTest       (backends);
664             BatchMatMul4DFp32SimpleTest       (backends);
665             BatchMatMul3DFp32BatchTest        (backends);
666             BatchMatMul3DFp32BroadcastTest    (backends);
667             BatchMatMul3D2DFp32BroadcastTest  (backends);
668             BatchMatMul2DFp32TinyTest         (backends);
669             BatchMatMulNonSquareFp32Test      (backends);
670             BatchMatMul2DFp32SimpleAdjointTest(backends);
671         }
672     }
673     TEST_SUITE("BATCH_MATMUL_GpuAccTests")
674     {
675         TEST_CASE("BATCH_MATMUL_Fp32_GpuAccTests")
676         {
677             std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc};
678             BatchMatMul2DFp32SimpleTest       (backends);
679             BatchMatMul3DFp32SimpleTest       (backends);
680             BatchMatMul4DFp32SimpleTest       (backends);
681             BatchMatMul3DFp32BatchTest        (backends);
682             BatchMatMul3DFp32BroadcastTest    (backends);
683             BatchMatMul3D2DFp32BroadcastTest  (backends);
684             BatchMatMul2DFp32TinyTest         (backends);
685             BatchMatMulNonSquareFp32Test      (backends);
686             BatchMatMul2DFp32SimpleAdjointTest(backends);
687         }
688     }
689 }
690