1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker #include <string> 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_BatchMatMul") 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker struct BatchMatMulFixture : public ParserFlatbuffersSerializeFixture 16*89c4ff92SAndroid Build Coastguard Worker { BatchMatMulFixtureBatchMatMulFixture17*89c4ff92SAndroid Build Coastguard Worker explicit BatchMatMulFixture(const std::string& inputXShape, 18*89c4ff92SAndroid Build Coastguard Worker const std::string& inputYShape, 19*89c4ff92SAndroid Build Coastguard Worker const std::string& outputShape, 20*89c4ff92SAndroid Build Coastguard Worker const std::string& dataType) 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 23*89c4ff92SAndroid Build Coastguard Worker { 24*89c4ff92SAndroid Build Coastguard Worker inputIds:[ 25*89c4ff92SAndroid Build Coastguard Worker 0, 26*89c4ff92SAndroid Build Coastguard Worker 1 27*89c4ff92SAndroid Build Coastguard Worker ], 28*89c4ff92SAndroid Build Coastguard Worker outputIds:[ 29*89c4ff92SAndroid Build Coastguard Worker 3 30*89c4ff92SAndroid Build Coastguard Worker ], 31*89c4ff92SAndroid Build Coastguard Worker layers:[ 32*89c4ff92SAndroid Build Coastguard Worker { 33*89c4ff92SAndroid Build Coastguard Worker layer_type:"InputLayer", 34*89c4ff92SAndroid Build Coastguard Worker layer:{ 35*89c4ff92SAndroid Build Coastguard Worker base:{ 36*89c4ff92SAndroid Build Coastguard Worker layerBindingId:0, 37*89c4ff92SAndroid Build Coastguard Worker base:{ 38*89c4ff92SAndroid Build Coastguard Worker index:0, 39*89c4ff92SAndroid Build Coastguard Worker layerName:"InputXLayer", 40*89c4ff92SAndroid Build Coastguard Worker layerType:"Input", 41*89c4ff92SAndroid Build Coastguard Worker inputSlots:[ 42*89c4ff92SAndroid Build Coastguard Worker { 43*89c4ff92SAndroid Build Coastguard Worker index:0, 44*89c4ff92SAndroid Build Coastguard Worker connection:{ 45*89c4ff92SAndroid Build Coastguard Worker sourceLayerIndex:0, 46*89c4ff92SAndroid Build Coastguard Worker outputSlotIndex:0 47*89c4ff92SAndroid Build Coastguard Worker }, 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker ], 51*89c4ff92SAndroid Build Coastguard Worker outputSlots:[ 52*89c4ff92SAndroid Build Coastguard Worker { 53*89c4ff92SAndroid Build Coastguard Worker index:0, 54*89c4ff92SAndroid Build Coastguard Worker tensorInfo:{ 55*89c4ff92SAndroid Build Coastguard Worker dimensions:)" + inputXShape + R"(, 56*89c4ff92SAndroid Build Coastguard Worker dataType:)" + dataType + R"( 57*89c4ff92SAndroid Build Coastguard Worker }, 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker } 60*89c4ff92SAndroid Build Coastguard Worker ], 61*89c4ff92SAndroid Build Coastguard Worker 62*89c4ff92SAndroid Build Coastguard Worker }, 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker }, 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker }, 68*89c4ff92SAndroid Build Coastguard Worker { 69*89c4ff92SAndroid Build Coastguard Worker layer_type:"InputLayer", 70*89c4ff92SAndroid Build Coastguard Worker layer:{ 71*89c4ff92SAndroid Build Coastguard Worker base:{ 72*89c4ff92SAndroid Build Coastguard Worker layerBindingId:1, 73*89c4ff92SAndroid Build Coastguard Worker base:{ 74*89c4ff92SAndroid Build Coastguard Worker index:1, 75*89c4ff92SAndroid Build Coastguard Worker layerName:"InputYLayer", 76*89c4ff92SAndroid Build Coastguard Worker layerType:"Input", 77*89c4ff92SAndroid Build Coastguard Worker inputSlots:[ 78*89c4ff92SAndroid Build Coastguard Worker { 79*89c4ff92SAndroid Build Coastguard Worker index:0, 80*89c4ff92SAndroid Build Coastguard Worker connection:{ 81*89c4ff92SAndroid Build Coastguard Worker sourceLayerIndex:0, 82*89c4ff92SAndroid Build Coastguard Worker outputSlotIndex:0 83*89c4ff92SAndroid Build Coastguard Worker }, 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker } 86*89c4ff92SAndroid Build Coastguard Worker ], 87*89c4ff92SAndroid Build Coastguard Worker outputSlots:[ 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker index:0, 90*89c4ff92SAndroid Build Coastguard Worker tensorInfo:{ 91*89c4ff92SAndroid Build Coastguard Worker dimensions:)" + inputYShape + R"(, 92*89c4ff92SAndroid Build Coastguard Worker dataType:)" + dataType + R"( 93*89c4ff92SAndroid Build Coastguard Worker }, 94*89c4ff92SAndroid Build Coastguard Worker 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker ], 97*89c4ff92SAndroid Build Coastguard Worker 98*89c4ff92SAndroid Build Coastguard Worker }, 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker } 101*89c4ff92SAndroid Build Coastguard Worker }, 102*89c4ff92SAndroid Build Coastguard Worker 103*89c4ff92SAndroid Build Coastguard Worker }, 104*89c4ff92SAndroid Build Coastguard Worker { 105*89c4ff92SAndroid Build Coastguard Worker layer_type:"BatchMatMulLayer", 106*89c4ff92SAndroid Build Coastguard Worker layer:{ 107*89c4ff92SAndroid Build Coastguard Worker base:{ 108*89c4ff92SAndroid Build Coastguard Worker index:2, 109*89c4ff92SAndroid Build Coastguard Worker layerName:"BatchMatMulLayer", 110*89c4ff92SAndroid Build Coastguard Worker layerType:"BatchMatMul", 111*89c4ff92SAndroid Build Coastguard Worker inputSlots:[ 112*89c4ff92SAndroid Build Coastguard Worker { 113*89c4ff92SAndroid Build Coastguard Worker index:0, 114*89c4ff92SAndroid Build Coastguard Worker connection:{ 115*89c4ff92SAndroid Build Coastguard Worker sourceLayerIndex:0, 116*89c4ff92SAndroid Build Coastguard Worker outputSlotIndex:0 117*89c4ff92SAndroid Build Coastguard Worker }, 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker }, 120*89c4ff92SAndroid Build Coastguard Worker { 121*89c4ff92SAndroid Build Coastguard Worker index:1, 122*89c4ff92SAndroid Build Coastguard Worker connection:{ 123*89c4ff92SAndroid Build Coastguard Worker sourceLayerIndex:1, 124*89c4ff92SAndroid Build Coastguard Worker outputSlotIndex:0 125*89c4ff92SAndroid Build Coastguard Worker }, 126*89c4ff92SAndroid Build Coastguard Worker 127*89c4ff92SAndroid Build Coastguard Worker } 128*89c4ff92SAndroid Build Coastguard Worker ], 129*89c4ff92SAndroid Build Coastguard Worker outputSlots:[ 130*89c4ff92SAndroid Build Coastguard Worker { 131*89c4ff92SAndroid Build Coastguard Worker index:0, 132*89c4ff92SAndroid Build Coastguard Worker tensorInfo:{ 133*89c4ff92SAndroid Build Coastguard Worker dimensions:)" + outputShape + R"(, 134*89c4ff92SAndroid Build Coastguard Worker dataType:)" + dataType + R"( 135*89c4ff92SAndroid Build Coastguard Worker }, 136*89c4ff92SAndroid Build Coastguard Worker 137*89c4ff92SAndroid Build Coastguard Worker } 138*89c4ff92SAndroid Build Coastguard Worker ], 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker }, 141*89c4ff92SAndroid Build Coastguard Worker descriptor:{ 142*89c4ff92SAndroid Build Coastguard Worker transposeX:false, 143*89c4ff92SAndroid Build Coastguard Worker transposeY:false, 144*89c4ff92SAndroid Build Coastguard Worker adjointX:false, 145*89c4ff92SAndroid Build Coastguard Worker adjointY:false, 146*89c4ff92SAndroid Build Coastguard Worker dataLayoutX:NHWC, 147*89c4ff92SAndroid Build Coastguard Worker dataLayoutY:NHWC 148*89c4ff92SAndroid Build Coastguard Worker } 149*89c4ff92SAndroid Build Coastguard Worker }, 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker }, 152*89c4ff92SAndroid Build Coastguard Worker { 153*89c4ff92SAndroid Build Coastguard Worker layer_type:"OutputLayer", 154*89c4ff92SAndroid Build Coastguard Worker layer:{ 155*89c4ff92SAndroid Build Coastguard Worker base:{ 156*89c4ff92SAndroid Build Coastguard Worker layerBindingId:0, 157*89c4ff92SAndroid Build Coastguard Worker base:{ 158*89c4ff92SAndroid Build Coastguard Worker index:3, 159*89c4ff92SAndroid Build Coastguard Worker layerName:"OutputLayer", 160*89c4ff92SAndroid Build Coastguard Worker layerType:"Output", 161*89c4ff92SAndroid Build Coastguard Worker inputSlots:[ 162*89c4ff92SAndroid Build Coastguard Worker { 163*89c4ff92SAndroid Build Coastguard Worker index:0, 164*89c4ff92SAndroid Build Coastguard Worker connection:{ 165*89c4ff92SAndroid Build Coastguard Worker sourceLayerIndex:2, 166*89c4ff92SAndroid Build Coastguard Worker outputSlotIndex:0 167*89c4ff92SAndroid Build Coastguard Worker }, 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker } 170*89c4ff92SAndroid Build Coastguard Worker ], 171*89c4ff92SAndroid Build Coastguard Worker outputSlots:[ 172*89c4ff92SAndroid Build Coastguard Worker { 173*89c4ff92SAndroid Build Coastguard Worker index:0, 174*89c4ff92SAndroid Build Coastguard Worker tensorInfo:{ 175*89c4ff92SAndroid Build Coastguard Worker dimensions:)" + outputShape + R"(, 176*89c4ff92SAndroid Build Coastguard Worker dataType:)" + dataType + R"( 177*89c4ff92SAndroid Build Coastguard Worker }, 178*89c4ff92SAndroid Build Coastguard Worker 179*89c4ff92SAndroid Build Coastguard Worker } 180*89c4ff92SAndroid Build Coastguard Worker ], 181*89c4ff92SAndroid Build Coastguard Worker 182*89c4ff92SAndroid Build Coastguard Worker } 183*89c4ff92SAndroid Build Coastguard Worker } 184*89c4ff92SAndroid Build Coastguard Worker }, 185*89c4ff92SAndroid Build Coastguard Worker 186*89c4ff92SAndroid Build Coastguard Worker } 187*89c4ff92SAndroid Build Coastguard Worker ] 188*89c4ff92SAndroid Build Coastguard Worker } 189*89c4ff92SAndroid Build Coastguard Worker )"; 190*89c4ff92SAndroid Build Coastguard Worker Setup(); 191*89c4ff92SAndroid Build Coastguard Worker } 192*89c4ff92SAndroid Build Coastguard Worker }; 193*89c4ff92SAndroid Build Coastguard Worker 194*89c4ff92SAndroid Build Coastguard Worker struct SimpleBatchMatMulFixture : BatchMatMulFixture 195*89c4ff92SAndroid Build Coastguard Worker { SimpleBatchMatMulFixtureSimpleBatchMatMulFixture196*89c4ff92SAndroid Build Coastguard Worker SimpleBatchMatMulFixture() 197*89c4ff92SAndroid Build Coastguard Worker : BatchMatMulFixture("[ 1, 2, 2, 1 ]", 198*89c4ff92SAndroid Build Coastguard Worker "[ 1, 2, 2, 1 ]", 199*89c4ff92SAndroid Build Coastguard Worker "[ 1, 2, 2, 1 ]", 200*89c4ff92SAndroid Build Coastguard Worker "Float32") 201*89c4ff92SAndroid Build Coastguard Worker {} 202*89c4ff92SAndroid Build Coastguard Worker }; 203*89c4ff92SAndroid Build Coastguard Worker 204*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleBatchMatMulFixture, "SimpleBatchMatMulTest") 205*89c4ff92SAndroid Build Coastguard Worker { 206*89c4ff92SAndroid Build Coastguard Worker RunTest<4, armnn::DataType::Float32>( 207*89c4ff92SAndroid Build Coastguard Worker 0, 208*89c4ff92SAndroid Build Coastguard Worker {{"InputXLayer", { 1.0f, 2.0f, 3.0f, 4.0f }}, 209*89c4ff92SAndroid Build Coastguard Worker {"InputYLayer", { 5.0f, 6.0f, 7.0f, 8.0f }}}, 210*89c4ff92SAndroid Build Coastguard Worker {{"OutputLayer", { 19.0f, 22.0f, 43.0f, 50.0f }}}); 211*89c4ff92SAndroid Build Coastguard Worker } 212*89c4ff92SAndroid Build Coastguard Worker 213*89c4ff92SAndroid Build Coastguard Worker }