xref: /aosp_15_r20/external/armnn/delegate/test/BatchMatMulTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "BatchMatMulTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId> & backends)18*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
19*89c4ff92SAndroid Build Coastguard Worker     {
20*89c4ff92SAndroid Build Coastguard Worker         // Set input data
21*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2, 2 };
22*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2, 2 };
23*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2, 2 };
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 1, 2,
26*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 5, 6,
29*89c4ff92SAndroid Build Coastguard Worker                                               7, 8  };
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 19, 22,
32*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50 };
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
35*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
36*89c4ff92SAndroid Build Coastguard Worker                                backends,
37*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
38*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
39*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
40*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
41*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
42*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
43*89c4ff92SAndroid Build Coastguard Worker                                false,
44*89c4ff92SAndroid Build Coastguard Worker                                false);
45*89c4ff92SAndroid Build Coastguard Worker     }
BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId> & backends)46*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
47*89c4ff92SAndroid Build Coastguard Worker     {
48*89c4ff92SAndroid Build Coastguard Worker         // Set input data
49*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2, 2 };
50*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2, 2 };
51*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2, 2 };
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 1, 2,
54*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 5, 6,
57*89c4ff92SAndroid Build Coastguard Worker                                               7, 8  };
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 19, 22,
60*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50 };
61*89c4ff92SAndroid Build Coastguard Worker 
62*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
63*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
64*89c4ff92SAndroid Build Coastguard Worker                                backends,
65*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
66*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
67*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
68*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
69*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
70*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
71*89c4ff92SAndroid Build Coastguard Worker                                false,
72*89c4ff92SAndroid Build Coastguard Worker                                false);
73*89c4ff92SAndroid Build Coastguard Worker     }
74*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId> & backends)75*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
76*89c4ff92SAndroid Build Coastguard Worker     {
77*89c4ff92SAndroid Build Coastguard Worker         // Set input data
78*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 1,2,2 };
79*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,2,2 };
80*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 1,2,2 };
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 1, 2,
83*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 5, 6,
86*89c4ff92SAndroid Build Coastguard Worker                                               7, 8  };
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 19, 22,
89*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50 };
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
92*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
93*89c4ff92SAndroid Build Coastguard Worker                                backends,
94*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
95*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
96*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
97*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
98*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
99*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
100*89c4ff92SAndroid Build Coastguard Worker                                false,
101*89c4ff92SAndroid Build Coastguard Worker                                false);
102*89c4ff92SAndroid Build Coastguard Worker     }
103*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId> & backends)104*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
105*89c4ff92SAndroid Build Coastguard Worker     {
106*89c4ff92SAndroid Build Coastguard Worker         // Set input data
107*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 1,2,2 };
108*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,2,2 };
109*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 1,2,2 };
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 1, 2,
112*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 5, 6,
115*89c4ff92SAndroid Build Coastguard Worker                                               7, 8  };
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 19, 22,
118*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50 };
119*89c4ff92SAndroid Build Coastguard Worker 
120*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
121*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
122*89c4ff92SAndroid Build Coastguard Worker                                backends,
123*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
124*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
125*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
126*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
127*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
128*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
129*89c4ff92SAndroid Build Coastguard Worker                                false,
130*89c4ff92SAndroid Build Coastguard Worker                                false);
131*89c4ff92SAndroid Build Coastguard Worker     }
132*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId> & backends)133*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId>& backends)
134*89c4ff92SAndroid Build Coastguard Worker     {
135*89c4ff92SAndroid Build Coastguard Worker         // Set input data
136*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 1,1,2,2 };
137*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,1,2,2 };
138*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 1,1,2,2 };
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 1, 2,
141*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 5, 6,
144*89c4ff92SAndroid Build Coastguard Worker                                               7, 8  };
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 19, 22,
147*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50 };
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
150*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
151*89c4ff92SAndroid Build Coastguard Worker                                backends,
152*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
153*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
154*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
155*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
156*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
157*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
158*89c4ff92SAndroid Build Coastguard Worker                                false,
159*89c4ff92SAndroid Build Coastguard Worker                                false);
160*89c4ff92SAndroid Build Coastguard Worker     }
161*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId> & backends)162*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId>& backends)
163*89c4ff92SAndroid Build Coastguard Worker     {
164*89c4ff92SAndroid Build Coastguard Worker         // Set input data
165*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 1,1,2,2};
166*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,1,2,2 };
167*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 1,1,2,2 };
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 1, 2,
170*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 5, 6,
173*89c4ff92SAndroid Build Coastguard Worker                                               7, 8 };
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 19, 22,
176*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50 };
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
179*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
180*89c4ff92SAndroid Build Coastguard Worker                                backends,
181*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
182*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
183*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
184*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
185*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
186*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
187*89c4ff92SAndroid Build Coastguard Worker                                false,
188*89c4ff92SAndroid Build Coastguard Worker                                false);
189*89c4ff92SAndroid Build Coastguard Worker     }
190*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId> & backends)191*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId>& backends)
192*89c4ff92SAndroid Build Coastguard Worker     {
193*89c4ff92SAndroid Build Coastguard Worker         // Set input data
194*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,2,2 };
195*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,2,2 };
196*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,2,2 };
197*89c4ff92SAndroid Build Coastguard Worker 
198*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 1, 2,
199*89c4ff92SAndroid Build Coastguard Worker                                               3, 4,
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker                                               9, 10,
202*89c4ff92SAndroid Build Coastguard Worker                                               11, 12 };
203*89c4ff92SAndroid Build Coastguard Worker 
204*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 5, 6,
205*89c4ff92SAndroid Build Coastguard Worker                                               7, 8,
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker                                               13, 14,
208*89c4ff92SAndroid Build Coastguard Worker                                               15, 16 };
209*89c4ff92SAndroid Build Coastguard Worker 
210*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 19, 22,
211*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50,
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker                                                     267, 286,
214*89c4ff92SAndroid Build Coastguard Worker                                                     323, 346 };
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
217*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
218*89c4ff92SAndroid Build Coastguard Worker                                backends,
219*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
220*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
221*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
222*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
223*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
224*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
225*89c4ff92SAndroid Build Coastguard Worker                                false,
226*89c4ff92SAndroid Build Coastguard Worker                                false);
227*89c4ff92SAndroid Build Coastguard Worker     }
228*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId> & backends)229*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId>& backends)
230*89c4ff92SAndroid Build Coastguard Worker     {
231*89c4ff92SAndroid Build Coastguard Worker         // Set input data
232*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,2,2 };
233*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,2,2 };
234*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,2,2 };
235*89c4ff92SAndroid Build Coastguard Worker 
236*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 1, 2,
237*89c4ff92SAndroid Build Coastguard Worker                                               3, 4,
238*89c4ff92SAndroid Build Coastguard Worker 
239*89c4ff92SAndroid Build Coastguard Worker                                               9, 10,
240*89c4ff92SAndroid Build Coastguard Worker                                               11, 12 };
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 5, 6,
243*89c4ff92SAndroid Build Coastguard Worker                                               7, 8,
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker                                               1, 2,
246*89c4ff92SAndroid Build Coastguard Worker                                               3, 4 };
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 19, 22,
249*89c4ff92SAndroid Build Coastguard Worker                                                     43, 50,
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker                                                     39, 58,
252*89c4ff92SAndroid Build Coastguard Worker                                                     47, 70 };
253*89c4ff92SAndroid Build Coastguard Worker 
254*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
255*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
256*89c4ff92SAndroid Build Coastguard Worker                                backends,
257*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
258*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
259*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
260*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
261*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
262*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
263*89c4ff92SAndroid Build Coastguard Worker                                false,
264*89c4ff92SAndroid Build Coastguard Worker                                false);
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId> & backends)267*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId>& backends)
268*89c4ff92SAndroid Build Coastguard Worker     {
269*89c4ff92SAndroid Build Coastguard Worker         // Set input data
270*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,2,2 };
271*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,2 };
272*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,2,2 };
273*89c4ff92SAndroid Build Coastguard Worker 
274*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 1, 2,
275*89c4ff92SAndroid Build Coastguard Worker                                               3, 4,
276*89c4ff92SAndroid Build Coastguard Worker 
277*89c4ff92SAndroid Build Coastguard Worker                                               9, 10,
278*89c4ff92SAndroid Build Coastguard Worker                                               11, 12 };
279*89c4ff92SAndroid Build Coastguard Worker 
280*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 13, 14,
281*89c4ff92SAndroid Build Coastguard Worker                                               15, 16 };
282*89c4ff92SAndroid Build Coastguard Worker 
283*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = {  43, 46,
284*89c4ff92SAndroid Build Coastguard Worker                                                      99, 106,
285*89c4ff92SAndroid Build Coastguard Worker 
286*89c4ff92SAndroid Build Coastguard Worker                                                      267, 286,
287*89c4ff92SAndroid Build Coastguard Worker                                                      323, 346 };
288*89c4ff92SAndroid Build Coastguard Worker 
289*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
290*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
291*89c4ff92SAndroid Build Coastguard Worker                                backends,
292*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
293*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
294*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
295*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
296*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
297*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
298*89c4ff92SAndroid Build Coastguard Worker                                false,
299*89c4ff92SAndroid Build Coastguard Worker                                false);
300*89c4ff92SAndroid Build Coastguard Worker     }
301*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId> & backends)302*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId>& backends)
303*89c4ff92SAndroid Build Coastguard Worker     {
304*89c4ff92SAndroid Build Coastguard Worker         // Set input data
305*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,2,2 };
306*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,2,2 };
307*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,2,2 };
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 1, 2,
310*89c4ff92SAndroid Build Coastguard Worker                                               3, 4,
311*89c4ff92SAndroid Build Coastguard Worker 
312*89c4ff92SAndroid Build Coastguard Worker                                               9, 10,
313*89c4ff92SAndroid Build Coastguard Worker                                               11, 12 };
314*89c4ff92SAndroid Build Coastguard Worker 
315*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 1, 2,
316*89c4ff92SAndroid Build Coastguard Worker                                                3, 4 };
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = {  7,  10,
319*89c4ff92SAndroid Build Coastguard Worker                                                       15, 22,
320*89c4ff92SAndroid Build Coastguard Worker 
321*89c4ff92SAndroid Build Coastguard Worker                                                       39, 58,
322*89c4ff92SAndroid Build Coastguard Worker                                                       47, 70 };
323*89c4ff92SAndroid Build Coastguard Worker 
324*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
325*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
326*89c4ff92SAndroid Build Coastguard Worker                                backends,
327*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
328*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
329*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
330*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
331*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
332*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
333*89c4ff92SAndroid Build Coastguard Worker                                false,
334*89c4ff92SAndroid Build Coastguard Worker                                false);
335*89c4ff92SAndroid Build Coastguard Worker     }
336*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId> & backends)337*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId>& backends)
338*89c4ff92SAndroid Build Coastguard Worker     {
339*89c4ff92SAndroid Build Coastguard Worker         // Set input data
340*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,2,2 };
341*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,2 };
342*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,2,2 };
343*89c4ff92SAndroid Build Coastguard Worker 
344*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 1, 2,
345*89c4ff92SAndroid Build Coastguard Worker                                               3, 4,
346*89c4ff92SAndroid Build Coastguard Worker 
347*89c4ff92SAndroid Build Coastguard Worker                                               9, 10,
348*89c4ff92SAndroid Build Coastguard Worker                                               11, 12 };
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 13, 14,
351*89c4ff92SAndroid Build Coastguard Worker                                               15, 16 };
352*89c4ff92SAndroid Build Coastguard Worker 
353*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = {  43, 46,
354*89c4ff92SAndroid Build Coastguard Worker                                                      99, 106,
355*89c4ff92SAndroid Build Coastguard Worker 
356*89c4ff92SAndroid Build Coastguard Worker                                                      267, 286,
357*89c4ff92SAndroid Build Coastguard Worker                                                      323, 346 };
358*89c4ff92SAndroid Build Coastguard Worker 
359*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
360*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
361*89c4ff92SAndroid Build Coastguard Worker                                backends,
362*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
363*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
364*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
365*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
366*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
367*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
368*89c4ff92SAndroid Build Coastguard Worker                                false,
369*89c4ff92SAndroid Build Coastguard Worker                                false);
370*89c4ff92SAndroid Build Coastguard Worker     }
371*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId> & backends)372*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId>& backends)
373*89c4ff92SAndroid Build Coastguard Worker     {
374*89c4ff92SAndroid Build Coastguard Worker         // Set input data
375*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,2,2 };
376*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,2 };
377*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,2,2 };
378*89c4ff92SAndroid Build Coastguard Worker 
379*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 1, 2,
380*89c4ff92SAndroid Build Coastguard Worker                                               3, 4,
381*89c4ff92SAndroid Build Coastguard Worker 
382*89c4ff92SAndroid Build Coastguard Worker                                               9, 10,
383*89c4ff92SAndroid Build Coastguard Worker                                               11, 12 };
384*89c4ff92SAndroid Build Coastguard Worker 
385*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 1, 2,
386*89c4ff92SAndroid Build Coastguard Worker                                                3, 4 };
387*89c4ff92SAndroid Build Coastguard Worker 
388*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = {  7, 10,
389*89c4ff92SAndroid Build Coastguard Worker                                                       15, 22,
390*89c4ff92SAndroid Build Coastguard Worker 
391*89c4ff92SAndroid Build Coastguard Worker                                                       39, 58,
392*89c4ff92SAndroid Build Coastguard Worker                                                       47, 70 };
393*89c4ff92SAndroid Build Coastguard Worker 
394*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
395*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
396*89c4ff92SAndroid Build Coastguard Worker                                backends,
397*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
398*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
399*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
400*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
401*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
402*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
403*89c4ff92SAndroid Build Coastguard Worker                                false,
404*89c4ff92SAndroid Build Coastguard Worker                                false);
405*89c4ff92SAndroid Build Coastguard Worker     }
406*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId> & backends)407*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId>& backends)
408*89c4ff92SAndroid Build Coastguard Worker     {
409*89c4ff92SAndroid Build Coastguard Worker         // Set input data
410*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 1,1 };
411*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,1 };
412*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 1,1 };
413*89c4ff92SAndroid Build Coastguard Worker 
414*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 3 };
415*89c4ff92SAndroid Build Coastguard Worker 
416*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 5 };
417*89c4ff92SAndroid Build Coastguard Worker 
418*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 15 };
419*89c4ff92SAndroid Build Coastguard Worker 
420*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
421*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
422*89c4ff92SAndroid Build Coastguard Worker                                backends,
423*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
424*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
425*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
426*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
427*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
428*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
429*89c4ff92SAndroid Build Coastguard Worker                                false,
430*89c4ff92SAndroid Build Coastguard Worker                                false);
431*89c4ff92SAndroid Build Coastguard Worker     }
BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId> & backends)432*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId>& backends)
433*89c4ff92SAndroid Build Coastguard Worker     {
434*89c4ff92SAndroid Build Coastguard Worker         // Set input data
435*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 1,1 };
436*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 1,1 };
437*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 1,1 };
438*89c4ff92SAndroid Build Coastguard Worker 
439*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 3 };
440*89c4ff92SAndroid Build Coastguard Worker 
441*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 5 };
442*89c4ff92SAndroid Build Coastguard Worker 
443*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 15 };
444*89c4ff92SAndroid Build Coastguard Worker 
445*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
446*89c4ff92SAndroid Build Coastguard Worker                                 ::tflite::TensorType_INT8,
447*89c4ff92SAndroid Build Coastguard Worker                                 backends,
448*89c4ff92SAndroid Build Coastguard Worker                                 LHSInputShape,
449*89c4ff92SAndroid Build Coastguard Worker                                 RHSInputShape,
450*89c4ff92SAndroid Build Coastguard Worker                                 outputShape,
451*89c4ff92SAndroid Build Coastguard Worker                                 LHSInputValues,
452*89c4ff92SAndroid Build Coastguard Worker                                 RHSInputValues,
453*89c4ff92SAndroid Build Coastguard Worker                                 expectedOutputValues,
454*89c4ff92SAndroid Build Coastguard Worker                                 false,
455*89c4ff92SAndroid Build Coastguard Worker                                 false);
456*89c4ff92SAndroid Build Coastguard Worker     }
457*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId> & backends)458*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId>& backends)
459*89c4ff92SAndroid Build Coastguard Worker     {
460*89c4ff92SAndroid Build Coastguard Worker         // Set input data
461*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,5,3 };
462*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,3,4 };
463*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,5,4 };
464*89c4ff92SAndroid Build Coastguard Worker 
465*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 8, 8, 4,
466*89c4ff92SAndroid Build Coastguard Worker                                               6, 1, 3,
467*89c4ff92SAndroid Build Coastguard Worker                                               8, 8, 3,
468*89c4ff92SAndroid Build Coastguard Worker                                               8, 9, 8,
469*89c4ff92SAndroid Build Coastguard Worker                                               5, 4, 4,
470*89c4ff92SAndroid Build Coastguard Worker 
471*89c4ff92SAndroid Build Coastguard Worker                                               1, 8, 5,
472*89c4ff92SAndroid Build Coastguard Worker                                               7, 1, 1,
473*89c4ff92SAndroid Build Coastguard Worker                                               8, 7, 9,
474*89c4ff92SAndroid Build Coastguard Worker                                               3, 2, 7,
475*89c4ff92SAndroid Build Coastguard Worker                                               8, 5, 3 };
476*89c4ff92SAndroid Build Coastguard Worker 
477*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 6, 2, 3, 2,
478*89c4ff92SAndroid Build Coastguard Worker                                               6, 2, 2, 8,
479*89c4ff92SAndroid Build Coastguard Worker                                               3, 7, 8, 1,
480*89c4ff92SAndroid Build Coastguard Worker 
481*89c4ff92SAndroid Build Coastguard Worker                                               7, 2, 9, 5,
482*89c4ff92SAndroid Build Coastguard Worker                                               2, 3, 1, 3,
483*89c4ff92SAndroid Build Coastguard Worker                                               2, 7, 7, 5 };
484*89c4ff92SAndroid Build Coastguard Worker 
485*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 108, 60, 72, 84,
486*89c4ff92SAndroid Build Coastguard Worker                                                     51, 35, 44, 23,
487*89c4ff92SAndroid Build Coastguard Worker                                                     105, 53, 64, 83,
488*89c4ff92SAndroid Build Coastguard Worker                                                     126, 90, 106, 96,
489*89c4ff92SAndroid Build Coastguard Worker                                                     66, 46, 55, 46,
490*89c4ff92SAndroid Build Coastguard Worker 
491*89c4ff92SAndroid Build Coastguard Worker                                                     33, 61, 52, 54,
492*89c4ff92SAndroid Build Coastguard Worker                                                     53, 24, 71, 43,
493*89c4ff92SAndroid Build Coastguard Worker                                                     88, 100, 142, 106,
494*89c4ff92SAndroid Build Coastguard Worker                                                     39, 61, 78, 56,
495*89c4ff92SAndroid Build Coastguard Worker                                                     72, 52, 98, 70 };
496*89c4ff92SAndroid Build Coastguard Worker 
497*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
498*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
499*89c4ff92SAndroid Build Coastguard Worker                                backends,
500*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
501*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
502*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
503*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
504*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
505*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
506*89c4ff92SAndroid Build Coastguard Worker                                false,
507*89c4ff92SAndroid Build Coastguard Worker                                false);
508*89c4ff92SAndroid Build Coastguard Worker     }
509*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId> & backends)510*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId>& backends)
511*89c4ff92SAndroid Build Coastguard Worker     {
512*89c4ff92SAndroid Build Coastguard Worker         // Set input data
513*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 2,5,3 };
514*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 2,3,4 };
515*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 2,5,4 };
516*89c4ff92SAndroid Build Coastguard Worker 
517*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 8, 8, 4,
518*89c4ff92SAndroid Build Coastguard Worker                                               6, 1, 3,
519*89c4ff92SAndroid Build Coastguard Worker                                               8, 8, 3,
520*89c4ff92SAndroid Build Coastguard Worker                                               8, 9, 8,
521*89c4ff92SAndroid Build Coastguard Worker                                               5, 4, 4,
522*89c4ff92SAndroid Build Coastguard Worker 
523*89c4ff92SAndroid Build Coastguard Worker                                               1, 8, 5,
524*89c4ff92SAndroid Build Coastguard Worker                                               7, 1, 1,
525*89c4ff92SAndroid Build Coastguard Worker                                               8, 7, 9,
526*89c4ff92SAndroid Build Coastguard Worker                                               3, 2, 7,
527*89c4ff92SAndroid Build Coastguard Worker                                               8, 5, 3 };
528*89c4ff92SAndroid Build Coastguard Worker 
529*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 6, 2, 3, 2,
530*89c4ff92SAndroid Build Coastguard Worker                                               6, 2, 2, 8,
531*89c4ff92SAndroid Build Coastguard Worker                                               3, 7, 8, 1,
532*89c4ff92SAndroid Build Coastguard Worker 
533*89c4ff92SAndroid Build Coastguard Worker                                               7, 2, 3, 5,
534*89c4ff92SAndroid Build Coastguard Worker                                               2, 3, 1, 3,
535*89c4ff92SAndroid Build Coastguard Worker                                               2, 7, 7, 5 };
536*89c4ff92SAndroid Build Coastguard Worker 
537*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 108, 60, 72, 84,
538*89c4ff92SAndroid Build Coastguard Worker                                                     51, 35, 44, 23,
539*89c4ff92SAndroid Build Coastguard Worker                                                     105, 53, 64, 83,
540*89c4ff92SAndroid Build Coastguard Worker                                                     126, 90, 106, 96,
541*89c4ff92SAndroid Build Coastguard Worker                                                     66, 46, 55, 46,
542*89c4ff92SAndroid Build Coastguard Worker 
543*89c4ff92SAndroid Build Coastguard Worker                                                     33, 61, 46, 54,
544*89c4ff92SAndroid Build Coastguard Worker                                                     53, 24, 29, 43,
545*89c4ff92SAndroid Build Coastguard Worker                                                     88, 100, 94, 106,
546*89c4ff92SAndroid Build Coastguard Worker                                                     39, 61, 60, 56,
547*89c4ff92SAndroid Build Coastguard Worker                                                     72, 52, 50, 70 };
548*89c4ff92SAndroid Build Coastguard Worker 
549*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
550*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
551*89c4ff92SAndroid Build Coastguard Worker                                backends,
552*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
553*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
554*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
555*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
556*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
557*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
558*89c4ff92SAndroid Build Coastguard Worker                                false,
559*89c4ff92SAndroid Build Coastguard Worker                                false);
560*89c4ff92SAndroid Build Coastguard Worker     }
561*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId> & backends)562*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId>& backends)
563*89c4ff92SAndroid Build Coastguard Worker     {
564*89c4ff92SAndroid Build Coastguard Worker         // Set input data
565*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 3,3 };
566*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 3,3 };
567*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 3,3 };
568*89c4ff92SAndroid Build Coastguard Worker 
569*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> LHSInputValues = { 3, 1, 1,
570*89c4ff92SAndroid Build Coastguard Worker                                               1, 3, -1,
571*89c4ff92SAndroid Build Coastguard Worker                                               2, 4, 1 };
572*89c4ff92SAndroid Build Coastguard Worker 
573*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> RHSInputValues = { 1, 0, 0,
574*89c4ff92SAndroid Build Coastguard Worker                                               0, 1, 0,
575*89c4ff92SAndroid Build Coastguard Worker                                               0, 0, 1 };
576*89c4ff92SAndroid Build Coastguard Worker 
577*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> expectedOutputValues = { 3, 1, 2,
578*89c4ff92SAndroid Build Coastguard Worker                                                     1, 3, 4,
579*89c4ff92SAndroid Build Coastguard Worker                                                     1, -1, 1 };
580*89c4ff92SAndroid Build Coastguard Worker 
581*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL,
582*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_FLOAT32,
583*89c4ff92SAndroid Build Coastguard Worker                                backends,
584*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
585*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
586*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
587*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
588*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
589*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
590*89c4ff92SAndroid Build Coastguard Worker                                true,
591*89c4ff92SAndroid Build Coastguard Worker                                false);
592*89c4ff92SAndroid Build Coastguard Worker     }
593*89c4ff92SAndroid Build Coastguard Worker 
BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId> & backends)594*89c4ff92SAndroid Build Coastguard Worker     void BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId>& backends)
595*89c4ff92SAndroid Build Coastguard Worker     {
596*89c4ff92SAndroid Build Coastguard Worker         // Set input data
597*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> LHSInputShape { 3,3 };
598*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> RHSInputShape { 3,3 };
599*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> outputShape   { 3,3 };
600*89c4ff92SAndroid Build Coastguard Worker 
601*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> LHSInputValues = { 3, 1, 1,
602*89c4ff92SAndroid Build Coastguard Worker                                               1, 3, -1,
603*89c4ff92SAndroid Build Coastguard Worker                                               2, 4, 1 };
604*89c4ff92SAndroid Build Coastguard Worker 
605*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> RHSInputValues = { 1, 0, 0,
606*89c4ff92SAndroid Build Coastguard Worker                                               0, 1, 0,
607*89c4ff92SAndroid Build Coastguard Worker                                               0, 0, 1 };
608*89c4ff92SAndroid Build Coastguard Worker 
609*89c4ff92SAndroid Build Coastguard Worker         std::vector<int8_t> expectedOutputValues = { 3, 1, 2,
610*89c4ff92SAndroid Build Coastguard Worker                                                      1, 3, 4,
611*89c4ff92SAndroid Build Coastguard Worker                                                      1, -1, 1 };
612*89c4ff92SAndroid Build Coastguard Worker 
613*89c4ff92SAndroid Build Coastguard Worker         BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL,
614*89c4ff92SAndroid Build Coastguard Worker                                ::tflite::TensorType_INT8,
615*89c4ff92SAndroid Build Coastguard Worker                                backends,
616*89c4ff92SAndroid Build Coastguard Worker                                LHSInputShape,
617*89c4ff92SAndroid Build Coastguard Worker                                RHSInputShape,
618*89c4ff92SAndroid Build Coastguard Worker                                outputShape,
619*89c4ff92SAndroid Build Coastguard Worker                                LHSInputValues,
620*89c4ff92SAndroid Build Coastguard Worker                                RHSInputValues,
621*89c4ff92SAndroid Build Coastguard Worker                                expectedOutputValues,
622*89c4ff92SAndroid Build Coastguard Worker                                true,
623*89c4ff92SAndroid Build Coastguard Worker                                false);
624*89c4ff92SAndroid Build Coastguard Worker     }
625*89c4ff92SAndroid Build Coastguard Worker 
626*89c4ff92SAndroid Build Coastguard Worker     TEST_SUITE("BATCH_MATMUL_CpuRefTests")
627*89c4ff92SAndroid Build Coastguard Worker     {
628*89c4ff92SAndroid Build Coastguard Worker         TEST_CASE("BATCH_MATMUL_Fp32_CpuRefTests")
629*89c4ff92SAndroid Build Coastguard Worker         {
630*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
631*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32SimpleTest       (backends);
632*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32SimpleTest       (backends);
633*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul4DFp32SimpleTest       (backends);
634*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32BatchTest        (backends);
635*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32BroadcastTest    (backends);
636*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3D2DFp32BroadcastTest  (backends);
637*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32TinyTest         (backends);
638*89c4ff92SAndroid Build Coastguard Worker             BatchMatMulNonSquareFp32Test      (backends);
639*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32SimpleAdjointTest(backends);
640*89c4ff92SAndroid Build Coastguard Worker         }
641*89c4ff92SAndroid Build Coastguard Worker 
642*89c4ff92SAndroid Build Coastguard Worker         TEST_CASE("BATCH_MATMUL_Int8_CpuRefTests")
643*89c4ff92SAndroid Build Coastguard Worker         {
644*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
645*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DInt8SimpleTest       (backends);
646*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DInt8SimpleTest       (backends);
647*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul4DInt8SimpleTest       (backends);
648*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DInt8BatchTest        (backends);
649*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DInt8BroadcastTest    (backends);
650*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3D2DInt8BroadcastTest  (backends);
651*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DInt8TinyTest         (backends);
652*89c4ff92SAndroid Build Coastguard Worker             BatchMatMulNonSquareInt8Test      (backends);
653*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DInt8SimpleAdjointTest(backends);
654*89c4ff92SAndroid Build Coastguard Worker         }
655*89c4ff92SAndroid Build Coastguard Worker     }
656*89c4ff92SAndroid Build Coastguard Worker 
657*89c4ff92SAndroid Build Coastguard Worker     TEST_SUITE("BATCH_MATMUL_CpuAccTests")
658*89c4ff92SAndroid Build Coastguard Worker     {
659*89c4ff92SAndroid Build Coastguard Worker         TEST_CASE("BATCH_MATMUL_Fp32_CpuAccTests")
660*89c4ff92SAndroid Build Coastguard Worker         {
661*89c4ff92SAndroid Build Coastguard Worker             std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
662*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32SimpleTest       (backends);
663*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32SimpleTest       (backends);
664*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul4DFp32SimpleTest       (backends);
665*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32BatchTest        (backends);
666*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32BroadcastTest    (backends);
667*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3D2DFp32BroadcastTest  (backends);
668*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32TinyTest         (backends);
669*89c4ff92SAndroid Build Coastguard Worker             BatchMatMulNonSquareFp32Test      (backends);
670*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32SimpleAdjointTest(backends);
671*89c4ff92SAndroid Build Coastguard Worker         }
672*89c4ff92SAndroid Build Coastguard Worker     }
673*89c4ff92SAndroid Build Coastguard Worker     TEST_SUITE("BATCH_MATMUL_GpuAccTests")
674*89c4ff92SAndroid Build Coastguard Worker     {
675*89c4ff92SAndroid Build Coastguard Worker         TEST_CASE("BATCH_MATMUL_Fp32_GpuAccTests")
676*89c4ff92SAndroid Build Coastguard Worker         {
677*89c4ff92SAndroid Build Coastguard Worker             std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc};
678*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32SimpleTest       (backends);
679*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32SimpleTest       (backends);
680*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul4DFp32SimpleTest       (backends);
681*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32BatchTest        (backends);
682*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3DFp32BroadcastTest    (backends);
683*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul3D2DFp32BroadcastTest  (backends);
684*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32TinyTest         (backends);
685*89c4ff92SAndroid Build Coastguard Worker             BatchMatMulNonSquareFp32Test      (backends);
686*89c4ff92SAndroid Build Coastguard Worker             BatchMatMul2DFp32SimpleAdjointTest(backends);
687*89c4ff92SAndroid Build Coastguard Worker         }
688*89c4ff92SAndroid Build Coastguard Worker     }
689*89c4ff92SAndroid Build Coastguard Worker }
690