1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "BatchMatMulTestHelper.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h> 11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 14*89c4ff92SAndroid Build Coastguard Worker 15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate 16*89c4ff92SAndroid Build Coastguard Worker { 17*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId> & backends)18*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul2DFp32SimpleTest(std::vector<armnn::BackendId>& backends) 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker // Set input data 21*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2, 2 }; 22*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2, 2 }; 23*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2, 2 }; 24*89c4ff92SAndroid Build Coastguard Worker 25*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 1, 2, 26*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 5, 6, 29*89c4ff92SAndroid Build Coastguard Worker 7, 8 }; 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 19, 22, 32*89c4ff92SAndroid Build Coastguard Worker 43, 50 }; 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 35*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 36*89c4ff92SAndroid Build Coastguard Worker backends, 37*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 38*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 39*89c4ff92SAndroid Build Coastguard Worker outputShape, 40*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 41*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 42*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 43*89c4ff92SAndroid Build Coastguard Worker false, 44*89c4ff92SAndroid Build Coastguard Worker false); 45*89c4ff92SAndroid Build Coastguard Worker } BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId> & backends)46*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul2DInt8SimpleTest(std::vector<armnn::BackendId>& backends) 47*89c4ff92SAndroid Build Coastguard Worker { 48*89c4ff92SAndroid Build Coastguard Worker // Set input data 49*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2, 2 }; 50*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2, 2 }; 51*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2, 2 }; 52*89c4ff92SAndroid Build Coastguard Worker 53*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 1, 2, 54*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 5, 6, 57*89c4ff92SAndroid Build Coastguard Worker 7, 8 }; 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 19, 22, 60*89c4ff92SAndroid Build Coastguard Worker 43, 50 }; 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 63*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 64*89c4ff92SAndroid Build Coastguard Worker backends, 65*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 66*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 67*89c4ff92SAndroid Build Coastguard Worker outputShape, 68*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 69*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 70*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 71*89c4ff92SAndroid Build Coastguard Worker false, 72*89c4ff92SAndroid Build Coastguard Worker false); 73*89c4ff92SAndroid Build Coastguard Worker } 74*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId> & backends)75*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3DFp32SimpleTest(std::vector<armnn::BackendId>& backends) 76*89c4ff92SAndroid Build Coastguard Worker { 77*89c4ff92SAndroid Build Coastguard Worker // Set input data 78*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 1,2,2 }; 79*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,2,2 }; 80*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1,2,2 }; 81*89c4ff92SAndroid Build Coastguard Worker 82*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 1, 2, 83*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 5, 6, 86*89c4ff92SAndroid Build Coastguard Worker 7, 8 }; 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 19, 22, 89*89c4ff92SAndroid Build Coastguard Worker 43, 50 }; 90*89c4ff92SAndroid Build Coastguard Worker 91*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 92*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 93*89c4ff92SAndroid Build Coastguard Worker backends, 94*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 95*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 96*89c4ff92SAndroid Build Coastguard Worker outputShape, 97*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 98*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 99*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 100*89c4ff92SAndroid Build Coastguard Worker false, 101*89c4ff92SAndroid Build Coastguard Worker false); 102*89c4ff92SAndroid Build Coastguard Worker } 103*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId> & backends)104*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3DInt8SimpleTest(std::vector<armnn::BackendId>& backends) 105*89c4ff92SAndroid Build Coastguard Worker { 106*89c4ff92SAndroid Build Coastguard Worker // Set input data 107*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 1,2,2 }; 108*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,2,2 }; 109*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1,2,2 }; 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 1, 2, 112*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 113*89c4ff92SAndroid Build Coastguard Worker 114*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 5, 6, 115*89c4ff92SAndroid Build Coastguard Worker 7, 8 }; 116*89c4ff92SAndroid Build Coastguard Worker 117*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 19, 22, 118*89c4ff92SAndroid Build Coastguard Worker 43, 50 }; 119*89c4ff92SAndroid Build Coastguard Worker 120*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 121*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 122*89c4ff92SAndroid Build Coastguard Worker backends, 123*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 124*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 125*89c4ff92SAndroid Build Coastguard Worker outputShape, 126*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 127*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 128*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 129*89c4ff92SAndroid Build Coastguard Worker false, 130*89c4ff92SAndroid Build Coastguard Worker false); 131*89c4ff92SAndroid Build Coastguard Worker } 132*89c4ff92SAndroid Build Coastguard Worker BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId> & backends)133*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul4DFp32SimpleTest(std::vector<armnn::BackendId>& backends) 134*89c4ff92SAndroid Build Coastguard Worker { 135*89c4ff92SAndroid Build Coastguard Worker // Set input data 136*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 1,1,2,2 }; 137*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,1,2,2 }; 138*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1,1,2,2 }; 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 1, 2, 141*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 142*89c4ff92SAndroid Build Coastguard Worker 143*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 5, 6, 144*89c4ff92SAndroid Build Coastguard Worker 7, 8 }; 145*89c4ff92SAndroid Build Coastguard Worker 146*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 19, 22, 147*89c4ff92SAndroid Build Coastguard Worker 43, 50 }; 148*89c4ff92SAndroid Build Coastguard Worker 149*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 150*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 151*89c4ff92SAndroid Build Coastguard Worker backends, 152*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 153*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 154*89c4ff92SAndroid Build Coastguard Worker outputShape, 155*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 156*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 157*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 158*89c4ff92SAndroid Build Coastguard Worker false, 159*89c4ff92SAndroid Build Coastguard Worker false); 160*89c4ff92SAndroid Build Coastguard Worker } 161*89c4ff92SAndroid Build Coastguard Worker BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId> & backends)162*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul4DInt8SimpleTest(std::vector<armnn::BackendId>& backends) 163*89c4ff92SAndroid Build Coastguard Worker { 164*89c4ff92SAndroid Build Coastguard Worker // Set input data 165*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 1,1,2,2}; 166*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,1,2,2 }; 167*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1,1,2,2 }; 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 1, 2, 170*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 171*89c4ff92SAndroid Build Coastguard Worker 172*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 5, 6, 173*89c4ff92SAndroid Build Coastguard Worker 7, 8 }; 174*89c4ff92SAndroid Build Coastguard Worker 175*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 19, 22, 176*89c4ff92SAndroid Build Coastguard Worker 43, 50 }; 177*89c4ff92SAndroid Build Coastguard Worker 178*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 179*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 180*89c4ff92SAndroid Build Coastguard Worker backends, 181*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 182*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 183*89c4ff92SAndroid Build Coastguard Worker outputShape, 184*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 185*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 186*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 187*89c4ff92SAndroid Build Coastguard Worker false, 188*89c4ff92SAndroid Build Coastguard Worker false); 189*89c4ff92SAndroid Build Coastguard Worker } 190*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId> & backends)191*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3DFp32BatchTest(std::vector<armnn::BackendId>& backends) 192*89c4ff92SAndroid Build Coastguard Worker { 193*89c4ff92SAndroid Build Coastguard Worker // Set input data 194*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,2,2 }; 195*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,2,2 }; 196*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,2,2 }; 197*89c4ff92SAndroid Build Coastguard Worker 198*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 1, 2, 199*89c4ff92SAndroid Build Coastguard Worker 3, 4, 200*89c4ff92SAndroid Build Coastguard Worker 201*89c4ff92SAndroid Build Coastguard Worker 9, 10, 202*89c4ff92SAndroid Build Coastguard Worker 11, 12 }; 203*89c4ff92SAndroid Build Coastguard Worker 204*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 5, 6, 205*89c4ff92SAndroid Build Coastguard Worker 7, 8, 206*89c4ff92SAndroid Build Coastguard Worker 207*89c4ff92SAndroid Build Coastguard Worker 13, 14, 208*89c4ff92SAndroid Build Coastguard Worker 15, 16 }; 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 19, 22, 211*89c4ff92SAndroid Build Coastguard Worker 43, 50, 212*89c4ff92SAndroid Build Coastguard Worker 213*89c4ff92SAndroid Build Coastguard Worker 267, 286, 214*89c4ff92SAndroid Build Coastguard Worker 323, 346 }; 215*89c4ff92SAndroid Build Coastguard Worker 216*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 217*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 218*89c4ff92SAndroid Build Coastguard Worker backends, 219*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 220*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 221*89c4ff92SAndroid Build Coastguard Worker outputShape, 222*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 223*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 224*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 225*89c4ff92SAndroid Build Coastguard Worker false, 226*89c4ff92SAndroid Build Coastguard Worker false); 227*89c4ff92SAndroid Build Coastguard Worker } 228*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId> & backends)229*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3DInt8BatchTest(std::vector<armnn::BackendId>& backends) 230*89c4ff92SAndroid Build Coastguard Worker { 231*89c4ff92SAndroid Build Coastguard Worker // Set input data 232*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,2,2 }; 233*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,2,2 }; 234*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,2,2 }; 235*89c4ff92SAndroid Build Coastguard Worker 236*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 1, 2, 237*89c4ff92SAndroid Build Coastguard Worker 3, 4, 238*89c4ff92SAndroid Build Coastguard Worker 239*89c4ff92SAndroid Build Coastguard Worker 9, 10, 240*89c4ff92SAndroid Build Coastguard Worker 11, 12 }; 241*89c4ff92SAndroid Build Coastguard Worker 242*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 5, 6, 243*89c4ff92SAndroid Build Coastguard Worker 7, 8, 244*89c4ff92SAndroid Build Coastguard Worker 245*89c4ff92SAndroid Build Coastguard Worker 1, 2, 246*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 247*89c4ff92SAndroid Build Coastguard Worker 248*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 19, 22, 249*89c4ff92SAndroid Build Coastguard Worker 43, 50, 250*89c4ff92SAndroid Build Coastguard Worker 251*89c4ff92SAndroid Build Coastguard Worker 39, 58, 252*89c4ff92SAndroid Build Coastguard Worker 47, 70 }; 253*89c4ff92SAndroid Build Coastguard Worker 254*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 255*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 256*89c4ff92SAndroid Build Coastguard Worker backends, 257*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 258*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 259*89c4ff92SAndroid Build Coastguard Worker outputShape, 260*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 261*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 262*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 263*89c4ff92SAndroid Build Coastguard Worker false, 264*89c4ff92SAndroid Build Coastguard Worker false); 265*89c4ff92SAndroid Build Coastguard Worker } 266*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId> & backends)267*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3DFp32BroadcastTest(std::vector<armnn::BackendId>& backends) 268*89c4ff92SAndroid Build Coastguard Worker { 269*89c4ff92SAndroid Build Coastguard Worker // Set input data 270*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,2,2 }; 271*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,2 }; 272*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,2,2 }; 273*89c4ff92SAndroid Build Coastguard Worker 274*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 1, 2, 275*89c4ff92SAndroid Build Coastguard Worker 3, 4, 276*89c4ff92SAndroid Build Coastguard Worker 277*89c4ff92SAndroid Build Coastguard Worker 9, 10, 278*89c4ff92SAndroid Build Coastguard Worker 11, 12 }; 279*89c4ff92SAndroid Build Coastguard Worker 280*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 13, 14, 281*89c4ff92SAndroid Build Coastguard Worker 15, 16 }; 282*89c4ff92SAndroid Build Coastguard Worker 283*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 43, 46, 284*89c4ff92SAndroid Build Coastguard Worker 99, 106, 285*89c4ff92SAndroid Build Coastguard Worker 286*89c4ff92SAndroid Build Coastguard Worker 267, 286, 287*89c4ff92SAndroid Build Coastguard Worker 323, 346 }; 288*89c4ff92SAndroid Build Coastguard Worker 289*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 290*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 291*89c4ff92SAndroid Build Coastguard Worker backends, 292*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 293*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 294*89c4ff92SAndroid Build Coastguard Worker outputShape, 295*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 296*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 297*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 298*89c4ff92SAndroid Build Coastguard Worker false, 299*89c4ff92SAndroid Build Coastguard Worker false); 300*89c4ff92SAndroid Build Coastguard Worker } 301*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId> & backends)302*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3DInt8BroadcastTest(std::vector<armnn::BackendId>& backends) 303*89c4ff92SAndroid Build Coastguard Worker { 304*89c4ff92SAndroid Build Coastguard Worker // Set input data 305*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,2,2 }; 306*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,2,2 }; 307*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,2,2 }; 308*89c4ff92SAndroid Build Coastguard Worker 309*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 1, 2, 310*89c4ff92SAndroid Build Coastguard Worker 3, 4, 311*89c4ff92SAndroid Build Coastguard Worker 312*89c4ff92SAndroid Build Coastguard Worker 9, 10, 313*89c4ff92SAndroid Build Coastguard Worker 11, 12 }; 314*89c4ff92SAndroid Build Coastguard Worker 315*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 1, 2, 316*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 317*89c4ff92SAndroid Build Coastguard Worker 318*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 7, 10, 319*89c4ff92SAndroid Build Coastguard Worker 15, 22, 320*89c4ff92SAndroid Build Coastguard Worker 321*89c4ff92SAndroid Build Coastguard Worker 39, 58, 322*89c4ff92SAndroid Build Coastguard Worker 47, 70 }; 323*89c4ff92SAndroid Build Coastguard Worker 324*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 325*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 326*89c4ff92SAndroid Build Coastguard Worker backends, 327*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 328*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 329*89c4ff92SAndroid Build Coastguard Worker outputShape, 330*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 331*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 332*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 333*89c4ff92SAndroid Build Coastguard Worker false, 334*89c4ff92SAndroid Build Coastguard Worker false); 335*89c4ff92SAndroid Build Coastguard Worker } 336*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId> & backends)337*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3D2DFp32BroadcastTest(std::vector<armnn::BackendId>& backends) 338*89c4ff92SAndroid Build Coastguard Worker { 339*89c4ff92SAndroid Build Coastguard Worker // Set input data 340*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,2,2 }; 341*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,2 }; 342*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,2,2 }; 343*89c4ff92SAndroid Build Coastguard Worker 344*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 1, 2, 345*89c4ff92SAndroid Build Coastguard Worker 3, 4, 346*89c4ff92SAndroid Build Coastguard Worker 347*89c4ff92SAndroid Build Coastguard Worker 9, 10, 348*89c4ff92SAndroid Build Coastguard Worker 11, 12 }; 349*89c4ff92SAndroid Build Coastguard Worker 350*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 13, 14, 351*89c4ff92SAndroid Build Coastguard Worker 15, 16 }; 352*89c4ff92SAndroid Build Coastguard Worker 353*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 43, 46, 354*89c4ff92SAndroid Build Coastguard Worker 99, 106, 355*89c4ff92SAndroid Build Coastguard Worker 356*89c4ff92SAndroid Build Coastguard Worker 267, 286, 357*89c4ff92SAndroid Build Coastguard Worker 323, 346 }; 358*89c4ff92SAndroid Build Coastguard Worker 359*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 360*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 361*89c4ff92SAndroid Build Coastguard Worker backends, 362*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 363*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 364*89c4ff92SAndroid Build Coastguard Worker outputShape, 365*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 366*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 367*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 368*89c4ff92SAndroid Build Coastguard Worker false, 369*89c4ff92SAndroid Build Coastguard Worker false); 370*89c4ff92SAndroid Build Coastguard Worker } 371*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId> & backends)372*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul3D2DInt8BroadcastTest(std::vector<armnn::BackendId>& backends) 373*89c4ff92SAndroid Build Coastguard Worker { 374*89c4ff92SAndroid Build Coastguard Worker // Set input data 375*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,2,2 }; 376*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,2 }; 377*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,2,2 }; 378*89c4ff92SAndroid Build Coastguard Worker 379*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 1, 2, 380*89c4ff92SAndroid Build Coastguard Worker 3, 4, 381*89c4ff92SAndroid Build Coastguard Worker 382*89c4ff92SAndroid Build Coastguard Worker 9, 10, 383*89c4ff92SAndroid Build Coastguard Worker 11, 12 }; 384*89c4ff92SAndroid Build Coastguard Worker 385*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 1, 2, 386*89c4ff92SAndroid Build Coastguard Worker 3, 4 }; 387*89c4ff92SAndroid Build Coastguard Worker 388*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 7, 10, 389*89c4ff92SAndroid Build Coastguard Worker 15, 22, 390*89c4ff92SAndroid Build Coastguard Worker 391*89c4ff92SAndroid Build Coastguard Worker 39, 58, 392*89c4ff92SAndroid Build Coastguard Worker 47, 70 }; 393*89c4ff92SAndroid Build Coastguard Worker 394*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 395*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 396*89c4ff92SAndroid Build Coastguard Worker backends, 397*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 398*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 399*89c4ff92SAndroid Build Coastguard Worker outputShape, 400*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 401*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 402*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 403*89c4ff92SAndroid Build Coastguard Worker false, 404*89c4ff92SAndroid Build Coastguard Worker false); 405*89c4ff92SAndroid Build Coastguard Worker } 406*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId> & backends)407*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul2DFp32TinyTest(std::vector<armnn::BackendId>& backends) 408*89c4ff92SAndroid Build Coastguard Worker { 409*89c4ff92SAndroid Build Coastguard Worker // Set input data 410*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 1,1 }; 411*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,1 }; 412*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1,1 }; 413*89c4ff92SAndroid Build Coastguard Worker 414*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 3 }; 415*89c4ff92SAndroid Build Coastguard Worker 416*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 5 }; 417*89c4ff92SAndroid Build Coastguard Worker 418*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 15 }; 419*89c4ff92SAndroid Build Coastguard Worker 420*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 421*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 422*89c4ff92SAndroid Build Coastguard Worker backends, 423*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 424*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 425*89c4ff92SAndroid Build Coastguard Worker outputShape, 426*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 427*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 428*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 429*89c4ff92SAndroid Build Coastguard Worker false, 430*89c4ff92SAndroid Build Coastguard Worker false); 431*89c4ff92SAndroid Build Coastguard Worker } BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId> & backends)432*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul2DInt8TinyTest(std::vector<armnn::BackendId>& backends) 433*89c4ff92SAndroid Build Coastguard Worker { 434*89c4ff92SAndroid Build Coastguard Worker // Set input data 435*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 1,1 }; 436*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 1,1 }; 437*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 1,1 }; 438*89c4ff92SAndroid Build Coastguard Worker 439*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 3 }; 440*89c4ff92SAndroid Build Coastguard Worker 441*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 5 }; 442*89c4ff92SAndroid Build Coastguard Worker 443*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 15 }; 444*89c4ff92SAndroid Build Coastguard Worker 445*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 446*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 447*89c4ff92SAndroid Build Coastguard Worker backends, 448*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 449*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 450*89c4ff92SAndroid Build Coastguard Worker outputShape, 451*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 452*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 453*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 454*89c4ff92SAndroid Build Coastguard Worker false, 455*89c4ff92SAndroid Build Coastguard Worker false); 456*89c4ff92SAndroid Build Coastguard Worker } 457*89c4ff92SAndroid Build Coastguard Worker BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId> & backends)458*89c4ff92SAndroid Build Coastguard Worker void BatchMatMulNonSquareFp32Test(std::vector<armnn::BackendId>& backends) 459*89c4ff92SAndroid Build Coastguard Worker { 460*89c4ff92SAndroid Build Coastguard Worker // Set input data 461*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,5,3 }; 462*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,3,4 }; 463*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,5,4 }; 464*89c4ff92SAndroid Build Coastguard Worker 465*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 8, 8, 4, 466*89c4ff92SAndroid Build Coastguard Worker 6, 1, 3, 467*89c4ff92SAndroid Build Coastguard Worker 8, 8, 3, 468*89c4ff92SAndroid Build Coastguard Worker 8, 9, 8, 469*89c4ff92SAndroid Build Coastguard Worker 5, 4, 4, 470*89c4ff92SAndroid Build Coastguard Worker 471*89c4ff92SAndroid Build Coastguard Worker 1, 8, 5, 472*89c4ff92SAndroid Build Coastguard Worker 7, 1, 1, 473*89c4ff92SAndroid Build Coastguard Worker 8, 7, 9, 474*89c4ff92SAndroid Build Coastguard Worker 3, 2, 7, 475*89c4ff92SAndroid Build Coastguard Worker 8, 5, 3 }; 476*89c4ff92SAndroid Build Coastguard Worker 477*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 6, 2, 3, 2, 478*89c4ff92SAndroid Build Coastguard Worker 6, 2, 2, 8, 479*89c4ff92SAndroid Build Coastguard Worker 3, 7, 8, 1, 480*89c4ff92SAndroid Build Coastguard Worker 481*89c4ff92SAndroid Build Coastguard Worker 7, 2, 9, 5, 482*89c4ff92SAndroid Build Coastguard Worker 2, 3, 1, 3, 483*89c4ff92SAndroid Build Coastguard Worker 2, 7, 7, 5 }; 484*89c4ff92SAndroid Build Coastguard Worker 485*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 108, 60, 72, 84, 486*89c4ff92SAndroid Build Coastguard Worker 51, 35, 44, 23, 487*89c4ff92SAndroid Build Coastguard Worker 105, 53, 64, 83, 488*89c4ff92SAndroid Build Coastguard Worker 126, 90, 106, 96, 489*89c4ff92SAndroid Build Coastguard Worker 66, 46, 55, 46, 490*89c4ff92SAndroid Build Coastguard Worker 491*89c4ff92SAndroid Build Coastguard Worker 33, 61, 52, 54, 492*89c4ff92SAndroid Build Coastguard Worker 53, 24, 71, 43, 493*89c4ff92SAndroid Build Coastguard Worker 88, 100, 142, 106, 494*89c4ff92SAndroid Build Coastguard Worker 39, 61, 78, 56, 495*89c4ff92SAndroid Build Coastguard Worker 72, 52, 98, 70 }; 496*89c4ff92SAndroid Build Coastguard Worker 497*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 498*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 499*89c4ff92SAndroid Build Coastguard Worker backends, 500*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 501*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 502*89c4ff92SAndroid Build Coastguard Worker outputShape, 503*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 504*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 505*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 506*89c4ff92SAndroid Build Coastguard Worker false, 507*89c4ff92SAndroid Build Coastguard Worker false); 508*89c4ff92SAndroid Build Coastguard Worker } 509*89c4ff92SAndroid Build Coastguard Worker BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId> & backends)510*89c4ff92SAndroid Build Coastguard Worker void BatchMatMulNonSquareInt8Test(std::vector<armnn::BackendId>& backends) 511*89c4ff92SAndroid Build Coastguard Worker { 512*89c4ff92SAndroid Build Coastguard Worker // Set input data 513*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 2,5,3 }; 514*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 2,3,4 }; 515*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 2,5,4 }; 516*89c4ff92SAndroid Build Coastguard Worker 517*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 8, 8, 4, 518*89c4ff92SAndroid Build Coastguard Worker 6, 1, 3, 519*89c4ff92SAndroid Build Coastguard Worker 8, 8, 3, 520*89c4ff92SAndroid Build Coastguard Worker 8, 9, 8, 521*89c4ff92SAndroid Build Coastguard Worker 5, 4, 4, 522*89c4ff92SAndroid Build Coastguard Worker 523*89c4ff92SAndroid Build Coastguard Worker 1, 8, 5, 524*89c4ff92SAndroid Build Coastguard Worker 7, 1, 1, 525*89c4ff92SAndroid Build Coastguard Worker 8, 7, 9, 526*89c4ff92SAndroid Build Coastguard Worker 3, 2, 7, 527*89c4ff92SAndroid Build Coastguard Worker 8, 5, 3 }; 528*89c4ff92SAndroid Build Coastguard Worker 529*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 6, 2, 3, 2, 530*89c4ff92SAndroid Build Coastguard Worker 6, 2, 2, 8, 531*89c4ff92SAndroid Build Coastguard Worker 3, 7, 8, 1, 532*89c4ff92SAndroid Build Coastguard Worker 533*89c4ff92SAndroid Build Coastguard Worker 7, 2, 3, 5, 534*89c4ff92SAndroid Build Coastguard Worker 2, 3, 1, 3, 535*89c4ff92SAndroid Build Coastguard Worker 2, 7, 7, 5 }; 536*89c4ff92SAndroid Build Coastguard Worker 537*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 108, 60, 72, 84, 538*89c4ff92SAndroid Build Coastguard Worker 51, 35, 44, 23, 539*89c4ff92SAndroid Build Coastguard Worker 105, 53, 64, 83, 540*89c4ff92SAndroid Build Coastguard Worker 126, 90, 106, 96, 541*89c4ff92SAndroid Build Coastguard Worker 66, 46, 55, 46, 542*89c4ff92SAndroid Build Coastguard Worker 543*89c4ff92SAndroid Build Coastguard Worker 33, 61, 46, 54, 544*89c4ff92SAndroid Build Coastguard Worker 53, 24, 29, 43, 545*89c4ff92SAndroid Build Coastguard Worker 88, 100, 94, 106, 546*89c4ff92SAndroid Build Coastguard Worker 39, 61, 60, 56, 547*89c4ff92SAndroid Build Coastguard Worker 72, 52, 50, 70 }; 548*89c4ff92SAndroid Build Coastguard Worker 549*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 550*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 551*89c4ff92SAndroid Build Coastguard Worker backends, 552*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 553*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 554*89c4ff92SAndroid Build Coastguard Worker outputShape, 555*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 556*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 557*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 558*89c4ff92SAndroid Build Coastguard Worker false, 559*89c4ff92SAndroid Build Coastguard Worker false); 560*89c4ff92SAndroid Build Coastguard Worker } 561*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId> & backends)562*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul2DFp32SimpleAdjointTest(std::vector<armnn::BackendId>& backends) 563*89c4ff92SAndroid Build Coastguard Worker { 564*89c4ff92SAndroid Build Coastguard Worker // Set input data 565*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 3,3 }; 566*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 3,3 }; 567*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 3,3 }; 568*89c4ff92SAndroid Build Coastguard Worker 569*89c4ff92SAndroid Build Coastguard Worker std::vector<float> LHSInputValues = { 3, 1, 1, 570*89c4ff92SAndroid Build Coastguard Worker 1, 3, -1, 571*89c4ff92SAndroid Build Coastguard Worker 2, 4, 1 }; 572*89c4ff92SAndroid Build Coastguard Worker 573*89c4ff92SAndroid Build Coastguard Worker std::vector<float> RHSInputValues = { 1, 0, 0, 574*89c4ff92SAndroid Build Coastguard Worker 0, 1, 0, 575*89c4ff92SAndroid Build Coastguard Worker 0, 0, 1 }; 576*89c4ff92SAndroid Build Coastguard Worker 577*89c4ff92SAndroid Build Coastguard Worker std::vector<float> expectedOutputValues = { 3, 1, 2, 578*89c4ff92SAndroid Build Coastguard Worker 1, 3, 4, 579*89c4ff92SAndroid Build Coastguard Worker 1, -1, 1 }; 580*89c4ff92SAndroid Build Coastguard Worker 581*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<float>(tflite::BuiltinOperator_BATCH_MATMUL, 582*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_FLOAT32, 583*89c4ff92SAndroid Build Coastguard Worker backends, 584*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 585*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 586*89c4ff92SAndroid Build Coastguard Worker outputShape, 587*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 588*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 589*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 590*89c4ff92SAndroid Build Coastguard Worker true, 591*89c4ff92SAndroid Build Coastguard Worker false); 592*89c4ff92SAndroid Build Coastguard Worker } 593*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId> & backends)594*89c4ff92SAndroid Build Coastguard Worker void BatchMatMul2DInt8SimpleAdjointTest(std::vector<armnn::BackendId>& backends) 595*89c4ff92SAndroid Build Coastguard Worker { 596*89c4ff92SAndroid Build Coastguard Worker // Set input data 597*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> LHSInputShape { 3,3 }; 598*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> RHSInputShape { 3,3 }; 599*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t> outputShape { 3,3 }; 600*89c4ff92SAndroid Build Coastguard Worker 601*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> LHSInputValues = { 3, 1, 1, 602*89c4ff92SAndroid Build Coastguard Worker 1, 3, -1, 603*89c4ff92SAndroid Build Coastguard Worker 2, 4, 1 }; 604*89c4ff92SAndroid Build Coastguard Worker 605*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> RHSInputValues = { 1, 0, 0, 606*89c4ff92SAndroid Build Coastguard Worker 0, 1, 0, 607*89c4ff92SAndroid Build Coastguard Worker 0, 0, 1 }; 608*89c4ff92SAndroid Build Coastguard Worker 609*89c4ff92SAndroid Build Coastguard Worker std::vector<int8_t> expectedOutputValues = { 3, 1, 2, 610*89c4ff92SAndroid Build Coastguard Worker 1, 3, 4, 611*89c4ff92SAndroid Build Coastguard Worker 1, -1, 1 }; 612*89c4ff92SAndroid Build Coastguard Worker 613*89c4ff92SAndroid Build Coastguard Worker BatchMatMulTest<int8_t>(tflite::BuiltinOperator_BATCH_MATMUL, 614*89c4ff92SAndroid Build Coastguard Worker ::tflite::TensorType_INT8, 615*89c4ff92SAndroid Build Coastguard Worker backends, 616*89c4ff92SAndroid Build Coastguard Worker LHSInputShape, 617*89c4ff92SAndroid Build Coastguard Worker RHSInputShape, 618*89c4ff92SAndroid Build Coastguard Worker outputShape, 619*89c4ff92SAndroid Build Coastguard Worker LHSInputValues, 620*89c4ff92SAndroid Build Coastguard Worker RHSInputValues, 621*89c4ff92SAndroid Build Coastguard Worker expectedOutputValues, 622*89c4ff92SAndroid Build Coastguard Worker true, 623*89c4ff92SAndroid Build Coastguard Worker false); 624*89c4ff92SAndroid Build Coastguard Worker } 625*89c4ff92SAndroid Build Coastguard Worker 626*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BATCH_MATMUL_CpuRefTests") 627*89c4ff92SAndroid Build Coastguard Worker { 628*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BATCH_MATMUL_Fp32_CpuRefTests") 629*89c4ff92SAndroid Build Coastguard Worker { 630*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef}; 631*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleTest (backends); 632*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32SimpleTest (backends); 633*89c4ff92SAndroid Build Coastguard Worker BatchMatMul4DFp32SimpleTest (backends); 634*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BatchTest (backends); 635*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BroadcastTest (backends); 636*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3D2DFp32BroadcastTest (backends); 637*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32TinyTest (backends); 638*89c4ff92SAndroid Build Coastguard Worker BatchMatMulNonSquareFp32Test (backends); 639*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleAdjointTest(backends); 640*89c4ff92SAndroid Build Coastguard Worker } 641*89c4ff92SAndroid Build Coastguard Worker 642*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BATCH_MATMUL_Int8_CpuRefTests") 643*89c4ff92SAndroid Build Coastguard Worker { 644*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef}; 645*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DInt8SimpleTest (backends); 646*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DInt8SimpleTest (backends); 647*89c4ff92SAndroid Build Coastguard Worker BatchMatMul4DInt8SimpleTest (backends); 648*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DInt8BatchTest (backends); 649*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DInt8BroadcastTest (backends); 650*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3D2DInt8BroadcastTest (backends); 651*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DInt8TinyTest (backends); 652*89c4ff92SAndroid Build Coastguard Worker BatchMatMulNonSquareInt8Test (backends); 653*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DInt8SimpleAdjointTest(backends); 654*89c4ff92SAndroid Build Coastguard Worker } 655*89c4ff92SAndroid Build Coastguard Worker } 656*89c4ff92SAndroid Build Coastguard Worker 657*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BATCH_MATMUL_CpuAccTests") 658*89c4ff92SAndroid Build Coastguard Worker { 659*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BATCH_MATMUL_Fp32_CpuAccTests") 660*89c4ff92SAndroid Build Coastguard Worker { 661*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc}; 662*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleTest (backends); 663*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32SimpleTest (backends); 664*89c4ff92SAndroid Build Coastguard Worker BatchMatMul4DFp32SimpleTest (backends); 665*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BatchTest (backends); 666*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BroadcastTest (backends); 667*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3D2DFp32BroadcastTest (backends); 668*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32TinyTest (backends); 669*89c4ff92SAndroid Build Coastguard Worker BatchMatMulNonSquareFp32Test (backends); 670*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleAdjointTest(backends); 671*89c4ff92SAndroid Build Coastguard Worker } 672*89c4ff92SAndroid Build Coastguard Worker } 673*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("BATCH_MATMUL_GpuAccTests") 674*89c4ff92SAndroid Build Coastguard Worker { 675*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BATCH_MATMUL_Fp32_GpuAccTests") 676*89c4ff92SAndroid Build Coastguard Worker { 677*89c4ff92SAndroid Build Coastguard Worker std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc}; 678*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleTest (backends); 679*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32SimpleTest (backends); 680*89c4ff92SAndroid Build Coastguard Worker BatchMatMul4DFp32SimpleTest (backends); 681*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BatchTest (backends); 682*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3DFp32BroadcastTest (backends); 683*89c4ff92SAndroid Build Coastguard Worker BatchMatMul3D2DFp32BroadcastTest (backends); 684*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32TinyTest (backends); 685*89c4ff92SAndroid Build Coastguard Worker BatchMatMulNonSquareFp32Test (backends); 686*89c4ff92SAndroid Build Coastguard Worker BatchMatMul2DFp32SimpleAdjointTest(backends); 687*89c4ff92SAndroid Build Coastguard Worker } 688*89c4ff92SAndroid Build Coastguard Worker } 689*89c4ff92SAndroid Build Coastguard Worker } 690