1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Gemm") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14*89c4ff92SAndroid Build Coastguard Worker { GemmFixtureGemmFixture15*89c4ff92SAndroid Build Coastguard Worker GemmFixture(const std::string& alpha, 16*89c4ff92SAndroid Build Coastguard Worker const std::string& beta, 17*89c4ff92SAndroid Build Coastguard Worker const std::string& transA, 18*89c4ff92SAndroid Build Coastguard Worker const std::string& transB, 19*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputAShape, 20*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputBShape, 21*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputCShape, 22*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& outputShape) 23*89c4ff92SAndroid Build Coastguard Worker { 24*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 25*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 26*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 27*89c4ff92SAndroid Build Coastguard Worker graph { 28*89c4ff92SAndroid Build Coastguard Worker node { 29*89c4ff92SAndroid Build Coastguard Worker input: "A" 30*89c4ff92SAndroid Build Coastguard Worker input: "B" 31*89c4ff92SAndroid Build Coastguard Worker input: "C" 32*89c4ff92SAndroid Build Coastguard Worker output: "Output" 33*89c4ff92SAndroid Build Coastguard Worker op_type: "Gemm" 34*89c4ff92SAndroid Build Coastguard Worker attribute { 35*89c4ff92SAndroid Build Coastguard Worker name: "alpha" 36*89c4ff92SAndroid Build Coastguard Worker f: )" + alpha + R"( 37*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker attribute { 40*89c4ff92SAndroid Build Coastguard Worker name: "beta" 41*89c4ff92SAndroid Build Coastguard Worker f: )" + beta + R"( 42*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 43*89c4ff92SAndroid Build Coastguard Worker } 44*89c4ff92SAndroid Build Coastguard Worker attribute { 45*89c4ff92SAndroid Build Coastguard Worker name: "transA" 46*89c4ff92SAndroid Build Coastguard Worker i: )" + transA + R"( 47*89c4ff92SAndroid Build Coastguard Worker type: INT 48*89c4ff92SAndroid Build Coastguard Worker } 49*89c4ff92SAndroid Build Coastguard Worker attribute { 50*89c4ff92SAndroid Build Coastguard Worker name: "transB" 51*89c4ff92SAndroid Build Coastguard Worker i: )" + transB + R"( 52*89c4ff92SAndroid Build Coastguard Worker type: INT 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker name: "gem-model" 56*89c4ff92SAndroid Build Coastguard Worker input { 57*89c4ff92SAndroid Build Coastguard Worker name: "A" 58*89c4ff92SAndroid Build Coastguard Worker type { 59*89c4ff92SAndroid Build Coastguard Worker tensor_type { 60*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 61*89c4ff92SAndroid Build Coastguard Worker shape { 62*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"( 63*89c4ff92SAndroid Build Coastguard Worker } 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker } 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker input { 68*89c4ff92SAndroid Build Coastguard Worker name: "B" 69*89c4ff92SAndroid Build Coastguard Worker type { 70*89c4ff92SAndroid Build Coastguard Worker tensor_type { 71*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 72*89c4ff92SAndroid Build Coastguard Worker shape { 73*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"( 74*89c4ff92SAndroid Build Coastguard Worker } 75*89c4ff92SAndroid Build Coastguard Worker } 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker } 78*89c4ff92SAndroid Build Coastguard Worker input { 79*89c4ff92SAndroid Build Coastguard Worker name: "C" 80*89c4ff92SAndroid Build Coastguard Worker type { 81*89c4ff92SAndroid Build Coastguard Worker tensor_type { 82*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 83*89c4ff92SAndroid Build Coastguard Worker shape { 84*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"( 85*89c4ff92SAndroid Build Coastguard Worker } 86*89c4ff92SAndroid Build Coastguard Worker } 87*89c4ff92SAndroid Build Coastguard Worker } 88*89c4ff92SAndroid Build Coastguard Worker } 89*89c4ff92SAndroid Build Coastguard Worker output { 90*89c4ff92SAndroid Build Coastguard Worker name: "Output" 91*89c4ff92SAndroid Build Coastguard Worker type { 92*89c4ff92SAndroid Build Coastguard Worker tensor_type { 93*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 94*89c4ff92SAndroid Build Coastguard Worker shape { 95*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 96*89c4ff92SAndroid Build Coastguard Worker } 97*89c4ff92SAndroid Build Coastguard Worker } 98*89c4ff92SAndroid Build Coastguard Worker } 99*89c4ff92SAndroid Build Coastguard Worker } 100*89c4ff92SAndroid Build Coastguard Worker })"; 101*89c4ff92SAndroid Build Coastguard Worker } 102*89c4ff92SAndroid Build Coastguard Worker }; 103*89c4ff92SAndroid Build Coastguard Worker 104*89c4ff92SAndroid Build Coastguard Worker struct GemmAllAttributesFixture : GemmFixture 105*89c4ff92SAndroid Build Coastguard Worker { GemmAllAttributesFixtureGemmAllAttributesFixture106*89c4ff92SAndroid Build Coastguard Worker GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 }) 107*89c4ff92SAndroid Build Coastguard Worker { 108*89c4ff92SAndroid Build Coastguard Worker Setup(); 109*89c4ff92SAndroid Build Coastguard Worker } 110*89c4ff92SAndroid Build Coastguard Worker }; 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker struct GemmSimpleFixture : GemmFixture 113*89c4ff92SAndroid Build Coastguard Worker { GemmSimpleFixtureGemmSimpleFixture114*89c4ff92SAndroid Build Coastguard Worker GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 }) 115*89c4ff92SAndroid Build Coastguard Worker { 116*89c4ff92SAndroid Build Coastguard Worker Setup(); 117*89c4ff92SAndroid Build Coastguard Worker } 118*89c4ff92SAndroid Build Coastguard Worker }; 119*89c4ff92SAndroid Build Coastguard Worker 120*89c4ff92SAndroid Build Coastguard Worker struct GemmTransAFixture : GemmFixture 121*89c4ff92SAndroid Build Coastguard Worker { GemmTransAFixtureGemmTransAFixture122*89c4ff92SAndroid Build Coastguard Worker GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 }) 123*89c4ff92SAndroid Build Coastguard Worker { 124*89c4ff92SAndroid Build Coastguard Worker Setup(); 125*89c4ff92SAndroid Build Coastguard Worker } 126*89c4ff92SAndroid Build Coastguard Worker }; 127*89c4ff92SAndroid Build Coastguard Worker 128*89c4ff92SAndroid Build Coastguard Worker struct GemmTransBFixture : GemmFixture 129*89c4ff92SAndroid Build Coastguard Worker { GemmTransBFixtureGemmTransBFixture130*89c4ff92SAndroid Build Coastguard Worker GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 }) 131*89c4ff92SAndroid Build Coastguard Worker { 132*89c4ff92SAndroid Build Coastguard Worker Setup(); 133*89c4ff92SAndroid Build Coastguard Worker } 134*89c4ff92SAndroid Build Coastguard Worker }; 135*89c4ff92SAndroid Build Coastguard Worker 136*89c4ff92SAndroid Build Coastguard Worker struct GemmParseExceptionFixture : GemmFixture 137*89c4ff92SAndroid Build Coastguard Worker { GemmParseExceptionFixtureGemmParseExceptionFixture138*89c4ff92SAndroid Build Coastguard Worker GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {} 139*89c4ff92SAndroid Build Coastguard Worker }; 140*89c4ff92SAndroid Build Coastguard Worker 141*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest") 142*89c4ff92SAndroid Build Coastguard Worker { 143*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 144*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 145*89c4ff92SAndroid Build Coastguard Worker {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 146*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 147*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 148*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 149*89c4ff92SAndroid Build Coastguard Worker {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 150*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f, 151*89c4ff92SAndroid Build Coastguard Worker 12.535f, 38.57f, 64.605f, 90.64f, 116.675f, 152*89c4ff92SAndroid Build Coastguard Worker 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}}); 153*89c4ff92SAndroid Build Coastguard Worker } 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest") 156*89c4ff92SAndroid Build Coastguard Worker { 157*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 158*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 159*89c4ff92SAndroid Build Coastguard Worker {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 160*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 161*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 162*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 163*89c4ff92SAndroid Build Coastguard Worker {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 164*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f, 165*89c4ff92SAndroid Build Coastguard Worker 196.1f, 222.2f, 248.3f, 274.4f, 300.5f, 166*89c4ff92SAndroid Build Coastguard Worker 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}}); 167*89c4ff92SAndroid Build Coastguard Worker } 168*89c4ff92SAndroid Build Coastguard Worker 169*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest") 170*89c4ff92SAndroid Build Coastguard Worker { 171*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 172*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 173*89c4ff92SAndroid Build Coastguard Worker {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 174*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 175*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 176*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 177*89c4ff92SAndroid Build Coastguard Worker {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 178*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f, 179*89c4ff92SAndroid Build Coastguard Worker 146.1f, 172.2f, 198.3f, 224.4f, 250.5f, 180*89c4ff92SAndroid Build Coastguard Worker 112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}}); 181*89c4ff92SAndroid Build Coastguard Worker } 182*89c4ff92SAndroid Build Coastguard Worker 183*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest") 184*89c4ff92SAndroid Build Coastguard Worker { 185*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 186*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 187*89c4ff92SAndroid Build Coastguard Worker {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 188*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 189*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 190*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}, 191*89c4ff92SAndroid Build Coastguard Worker {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}}, 192*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f, 193*89c4ff92SAndroid Build Coastguard Worker 60.1f, 164.2f, 268.3f, 372.4f, 476.5f, 194*89c4ff92SAndroid Build Coastguard Worker 20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}}); 195*89c4ff92SAndroid Build Coastguard Worker } 196*89c4ff92SAndroid Build Coastguard Worker 197*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest") 198*89c4ff92SAndroid Build Coastguard Worker { 199*89c4ff92SAndroid Build Coastguard Worker // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension) 200*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup(), armnn::ParseException); 201*89c4ff92SAndroid Build Coastguard Worker } 202*89c4ff92SAndroid Build Coastguard Worker 203*89c4ff92SAndroid Build Coastguard Worker struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 204*89c4ff92SAndroid Build Coastguard Worker { GemmConstantFixtureGemmConstantFixture205*89c4ff92SAndroid Build Coastguard Worker GemmConstantFixture() 206*89c4ff92SAndroid Build Coastguard Worker { 207*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 208*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 209*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 210*89c4ff92SAndroid Build Coastguard Worker graph { 211*89c4ff92SAndroid Build Coastguard Worker node { 212*89c4ff92SAndroid Build Coastguard Worker input: "A" 213*89c4ff92SAndroid Build Coastguard Worker input: "B" 214*89c4ff92SAndroid Build Coastguard Worker input: "C" 215*89c4ff92SAndroid Build Coastguard Worker output: "Output" 216*89c4ff92SAndroid Build Coastguard Worker op_type: "Gemm" 217*89c4ff92SAndroid Build Coastguard Worker attribute { 218*89c4ff92SAndroid Build Coastguard Worker name: "alpha" 219*89c4ff92SAndroid Build Coastguard Worker f: 0.25 220*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 221*89c4ff92SAndroid Build Coastguard Worker } 222*89c4ff92SAndroid Build Coastguard Worker attribute { 223*89c4ff92SAndroid Build Coastguard Worker name: "beta" 224*89c4ff92SAndroid Build Coastguard Worker f: 0.35 225*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 226*89c4ff92SAndroid Build Coastguard Worker } 227*89c4ff92SAndroid Build Coastguard Worker attribute { 228*89c4ff92SAndroid Build Coastguard Worker name: "transA" 229*89c4ff92SAndroid Build Coastguard Worker i: 1 230*89c4ff92SAndroid Build Coastguard Worker type: INT 231*89c4ff92SAndroid Build Coastguard Worker } 232*89c4ff92SAndroid Build Coastguard Worker attribute { 233*89c4ff92SAndroid Build Coastguard Worker name: "transB" 234*89c4ff92SAndroid Build Coastguard Worker i: 1 235*89c4ff92SAndroid Build Coastguard Worker type: INT 236*89c4ff92SAndroid Build Coastguard Worker } 237*89c4ff92SAndroid Build Coastguard Worker } 238*89c4ff92SAndroid Build Coastguard Worker name: "gem-model" 239*89c4ff92SAndroid Build Coastguard Worker initializer { 240*89c4ff92SAndroid Build Coastguard Worker dims: 5 241*89c4ff92SAndroid Build Coastguard Worker dims: 4 242*89c4ff92SAndroid Build Coastguard Worker data_type: 1 243*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 244*89c4ff92SAndroid Build Coastguard Worker float_data: 2.0 245*89c4ff92SAndroid Build Coastguard Worker float_data: 3.0 246*89c4ff92SAndroid Build Coastguard Worker float_data: 4.0 247*89c4ff92SAndroid Build Coastguard Worker float_data: 5.0 248*89c4ff92SAndroid Build Coastguard Worker float_data: 6.0 249*89c4ff92SAndroid Build Coastguard Worker float_data: 7.0 250*89c4ff92SAndroid Build Coastguard Worker float_data: 8.0 251*89c4ff92SAndroid Build Coastguard Worker float_data: 9.0 252*89c4ff92SAndroid Build Coastguard Worker float_data: 10.0 253*89c4ff92SAndroid Build Coastguard Worker float_data: 11.0 254*89c4ff92SAndroid Build Coastguard Worker float_data: 12.0 255*89c4ff92SAndroid Build Coastguard Worker float_data: 13.0 256*89c4ff92SAndroid Build Coastguard Worker float_data: 14.0 257*89c4ff92SAndroid Build Coastguard Worker float_data: 15.0 258*89c4ff92SAndroid Build Coastguard Worker float_data: 16.0 259*89c4ff92SAndroid Build Coastguard Worker float_data: 17.0 260*89c4ff92SAndroid Build Coastguard Worker float_data: 18.0 261*89c4ff92SAndroid Build Coastguard Worker float_data: 19.0 262*89c4ff92SAndroid Build Coastguard Worker float_data: 20.0 263*89c4ff92SAndroid Build Coastguard Worker name: "B" 264*89c4ff92SAndroid Build Coastguard Worker } 265*89c4ff92SAndroid Build Coastguard Worker initializer { 266*89c4ff92SAndroid Build Coastguard Worker dims: 1 267*89c4ff92SAndroid Build Coastguard Worker dims: 5 268*89c4ff92SAndroid Build Coastguard Worker data_type: 1 269*89c4ff92SAndroid Build Coastguard Worker float_data: 0.1 270*89c4ff92SAndroid Build Coastguard Worker float_data: 0.2 271*89c4ff92SAndroid Build Coastguard Worker float_data: 0.3 272*89c4ff92SAndroid Build Coastguard Worker float_data: 0.4 273*89c4ff92SAndroid Build Coastguard Worker float_data: 0.5 274*89c4ff92SAndroid Build Coastguard Worker name: "C" 275*89c4ff92SAndroid Build Coastguard Worker } 276*89c4ff92SAndroid Build Coastguard Worker input { 277*89c4ff92SAndroid Build Coastguard Worker name: "A" 278*89c4ff92SAndroid Build Coastguard Worker type { 279*89c4ff92SAndroid Build Coastguard Worker tensor_type { 280*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 281*89c4ff92SAndroid Build Coastguard Worker shape { 282*89c4ff92SAndroid Build Coastguard Worker dim { 283*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 284*89c4ff92SAndroid Build Coastguard Worker } 285*89c4ff92SAndroid Build Coastguard Worker dim { 286*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 287*89c4ff92SAndroid Build Coastguard Worker } 288*89c4ff92SAndroid Build Coastguard Worker } 289*89c4ff92SAndroid Build Coastguard Worker } 290*89c4ff92SAndroid Build Coastguard Worker } 291*89c4ff92SAndroid Build Coastguard Worker } 292*89c4ff92SAndroid Build Coastguard Worker output { 293*89c4ff92SAndroid Build Coastguard Worker name: "Output" 294*89c4ff92SAndroid Build Coastguard Worker type { 295*89c4ff92SAndroid Build Coastguard Worker tensor_type { 296*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 297*89c4ff92SAndroid Build Coastguard Worker shape { 298*89c4ff92SAndroid Build Coastguard Worker dim { 299*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 300*89c4ff92SAndroid Build Coastguard Worker } 301*89c4ff92SAndroid Build Coastguard Worker dim { 302*89c4ff92SAndroid Build Coastguard Worker dim_value: 5 303*89c4ff92SAndroid Build Coastguard Worker } 304*89c4ff92SAndroid Build Coastguard Worker } 305*89c4ff92SAndroid Build Coastguard Worker } 306*89c4ff92SAndroid Build Coastguard Worker } 307*89c4ff92SAndroid Build Coastguard Worker } 308*89c4ff92SAndroid Build Coastguard Worker })"; 309*89c4ff92SAndroid Build Coastguard Worker Setup(); 310*89c4ff92SAndroid Build Coastguard Worker } 311*89c4ff92SAndroid Build Coastguard Worker }; 312*89c4ff92SAndroid Build Coastguard Worker 313*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest") 314*89c4ff92SAndroid Build Coastguard Worker { 315*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 316*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}}, 317*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f, 318*89c4ff92SAndroid Build Coastguard Worker 12.535f, 38.57f, 64.605f, 90.64f, 116.675f, 319*89c4ff92SAndroid Build Coastguard Worker 10.035f, 32.07f, 54.105f, 76.14f, 98.175f }}}); 320*89c4ff92SAndroid Build Coastguard Worker } 321*89c4ff92SAndroid Build Coastguard Worker 322*89c4ff92SAndroid Build Coastguard Worker struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 323*89c4ff92SAndroid Build Coastguard Worker { GemmConstantSimpleFixtureGemmConstantSimpleFixture324*89c4ff92SAndroid Build Coastguard Worker GemmConstantSimpleFixture() 325*89c4ff92SAndroid Build Coastguard Worker { 326*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 327*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 328*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 329*89c4ff92SAndroid Build Coastguard Worker graph { 330*89c4ff92SAndroid Build Coastguard Worker node { 331*89c4ff92SAndroid Build Coastguard Worker input: "A" 332*89c4ff92SAndroid Build Coastguard Worker input: "B" 333*89c4ff92SAndroid Build Coastguard Worker input: "C" 334*89c4ff92SAndroid Build Coastguard Worker output: "Output" 335*89c4ff92SAndroid Build Coastguard Worker op_type: "Gemm" 336*89c4ff92SAndroid Build Coastguard Worker attribute { 337*89c4ff92SAndroid Build Coastguard Worker name: "alpha" 338*89c4ff92SAndroid Build Coastguard Worker f: 1 339*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 340*89c4ff92SAndroid Build Coastguard Worker } 341*89c4ff92SAndroid Build Coastguard Worker attribute { 342*89c4ff92SAndroid Build Coastguard Worker name: "beta" 343*89c4ff92SAndroid Build Coastguard Worker f: 1 344*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 345*89c4ff92SAndroid Build Coastguard Worker } 346*89c4ff92SAndroid Build Coastguard Worker attribute { 347*89c4ff92SAndroid Build Coastguard Worker name: "transA" 348*89c4ff92SAndroid Build Coastguard Worker i: 0 349*89c4ff92SAndroid Build Coastguard Worker type: INT 350*89c4ff92SAndroid Build Coastguard Worker } 351*89c4ff92SAndroid Build Coastguard Worker attribute { 352*89c4ff92SAndroid Build Coastguard Worker name: "transB" 353*89c4ff92SAndroid Build Coastguard Worker i: 0 354*89c4ff92SAndroid Build Coastguard Worker type: INT 355*89c4ff92SAndroid Build Coastguard Worker } 356*89c4ff92SAndroid Build Coastguard Worker } 357*89c4ff92SAndroid Build Coastguard Worker name: "gem-model" 358*89c4ff92SAndroid Build Coastguard Worker initializer { 359*89c4ff92SAndroid Build Coastguard Worker dims: 4 360*89c4ff92SAndroid Build Coastguard Worker dims: 5 361*89c4ff92SAndroid Build Coastguard Worker data_type: 1 362*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 363*89c4ff92SAndroid Build Coastguard Worker float_data: 2.0 364*89c4ff92SAndroid Build Coastguard Worker float_data: 3.0 365*89c4ff92SAndroid Build Coastguard Worker float_data: 4.0 366*89c4ff92SAndroid Build Coastguard Worker float_data: 5.0 367*89c4ff92SAndroid Build Coastguard Worker float_data: 6.0 368*89c4ff92SAndroid Build Coastguard Worker float_data: 7.0 369*89c4ff92SAndroid Build Coastguard Worker float_data: 8.0 370*89c4ff92SAndroid Build Coastguard Worker float_data: 9.0 371*89c4ff92SAndroid Build Coastguard Worker float_data: 10.0 372*89c4ff92SAndroid Build Coastguard Worker float_data: 11.0 373*89c4ff92SAndroid Build Coastguard Worker float_data: 12.0 374*89c4ff92SAndroid Build Coastguard Worker float_data: 13.0 375*89c4ff92SAndroid Build Coastguard Worker float_data: 14.0 376*89c4ff92SAndroid Build Coastguard Worker float_data: 15.0 377*89c4ff92SAndroid Build Coastguard Worker float_data: 16.0 378*89c4ff92SAndroid Build Coastguard Worker float_data: 17.0 379*89c4ff92SAndroid Build Coastguard Worker float_data: 18.0 380*89c4ff92SAndroid Build Coastguard Worker float_data: 19.0 381*89c4ff92SAndroid Build Coastguard Worker float_data: 20.0 382*89c4ff92SAndroid Build Coastguard Worker name: "B" 383*89c4ff92SAndroid Build Coastguard Worker } 384*89c4ff92SAndroid Build Coastguard Worker initializer { 385*89c4ff92SAndroid Build Coastguard Worker dims: 1 386*89c4ff92SAndroid Build Coastguard Worker dims: 5 387*89c4ff92SAndroid Build Coastguard Worker data_type: 1 388*89c4ff92SAndroid Build Coastguard Worker float_data: 0.1 389*89c4ff92SAndroid Build Coastguard Worker float_data: 0.2 390*89c4ff92SAndroid Build Coastguard Worker float_data: 0.3 391*89c4ff92SAndroid Build Coastguard Worker float_data: 0.4 392*89c4ff92SAndroid Build Coastguard Worker float_data: 0.5 393*89c4ff92SAndroid Build Coastguard Worker name: "C" 394*89c4ff92SAndroid Build Coastguard Worker } 395*89c4ff92SAndroid Build Coastguard Worker input { 396*89c4ff92SAndroid Build Coastguard Worker name: "A" 397*89c4ff92SAndroid Build Coastguard Worker type { 398*89c4ff92SAndroid Build Coastguard Worker tensor_type { 399*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 400*89c4ff92SAndroid Build Coastguard Worker shape { 401*89c4ff92SAndroid Build Coastguard Worker dim { 402*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 403*89c4ff92SAndroid Build Coastguard Worker } 404*89c4ff92SAndroid Build Coastguard Worker dim { 405*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 406*89c4ff92SAndroid Build Coastguard Worker } 407*89c4ff92SAndroid Build Coastguard Worker } 408*89c4ff92SAndroid Build Coastguard Worker } 409*89c4ff92SAndroid Build Coastguard Worker } 410*89c4ff92SAndroid Build Coastguard Worker } 411*89c4ff92SAndroid Build Coastguard Worker output { 412*89c4ff92SAndroid Build Coastguard Worker name: "Output" 413*89c4ff92SAndroid Build Coastguard Worker type { 414*89c4ff92SAndroid Build Coastguard Worker tensor_type { 415*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 416*89c4ff92SAndroid Build Coastguard Worker shape { 417*89c4ff92SAndroid Build Coastguard Worker dim { 418*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 419*89c4ff92SAndroid Build Coastguard Worker } 420*89c4ff92SAndroid Build Coastguard Worker dim { 421*89c4ff92SAndroid Build Coastguard Worker dim_value: 5 422*89c4ff92SAndroid Build Coastguard Worker } 423*89c4ff92SAndroid Build Coastguard Worker } 424*89c4ff92SAndroid Build Coastguard Worker } 425*89c4ff92SAndroid Build Coastguard Worker } 426*89c4ff92SAndroid Build Coastguard Worker } 427*89c4ff92SAndroid Build Coastguard Worker })"; 428*89c4ff92SAndroid Build Coastguard Worker Setup(); 429*89c4ff92SAndroid Build Coastguard Worker } 430*89c4ff92SAndroid Build Coastguard Worker }; 431*89c4ff92SAndroid Build Coastguard Worker 432*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest") 433*89c4ff92SAndroid Build Coastguard Worker { 434*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 435*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}}, 436*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f, 437*89c4ff92SAndroid Build Coastguard Worker 196.1f, 222.2f, 248.3f, 274.4f, 300.5f, 438*89c4ff92SAndroid Build Coastguard Worker 60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}}); 439*89c4ff92SAndroid Build Coastguard Worker } 440*89c4ff92SAndroid Build Coastguard Worker 441*89c4ff92SAndroid Build Coastguard Worker struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 442*89c4ff92SAndroid Build Coastguard Worker { GemmABFixtureGemmABFixture443*89c4ff92SAndroid Build Coastguard Worker GemmABFixture(const std::string& alpha, 444*89c4ff92SAndroid Build Coastguard Worker const std::string& beta, 445*89c4ff92SAndroid Build Coastguard Worker const std::string& transA, 446*89c4ff92SAndroid Build Coastguard Worker const std::string& transB, 447*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputAShape, 448*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputBShape, 449*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& outputShape) 450*89c4ff92SAndroid Build Coastguard Worker { 451*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 452*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 453*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 454*89c4ff92SAndroid Build Coastguard Worker graph { 455*89c4ff92SAndroid Build Coastguard Worker node { 456*89c4ff92SAndroid Build Coastguard Worker input: "A" 457*89c4ff92SAndroid Build Coastguard Worker input: "B" 458*89c4ff92SAndroid Build Coastguard Worker output: "Output" 459*89c4ff92SAndroid Build Coastguard Worker op_type: "Gemm" 460*89c4ff92SAndroid Build Coastguard Worker attribute { 461*89c4ff92SAndroid Build Coastguard Worker name: "alpha" 462*89c4ff92SAndroid Build Coastguard Worker f: )" + alpha + R"( 463*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 464*89c4ff92SAndroid Build Coastguard Worker } 465*89c4ff92SAndroid Build Coastguard Worker attribute { 466*89c4ff92SAndroid Build Coastguard Worker name: "beta" 467*89c4ff92SAndroid Build Coastguard Worker f: )" + beta + R"( 468*89c4ff92SAndroid Build Coastguard Worker type: FLOAT 469*89c4ff92SAndroid Build Coastguard Worker } 470*89c4ff92SAndroid Build Coastguard Worker attribute { 471*89c4ff92SAndroid Build Coastguard Worker name: "transA" 472*89c4ff92SAndroid Build Coastguard Worker i: )" + transA + R"( 473*89c4ff92SAndroid Build Coastguard Worker type: INT 474*89c4ff92SAndroid Build Coastguard Worker } 475*89c4ff92SAndroid Build Coastguard Worker attribute { 476*89c4ff92SAndroid Build Coastguard Worker name: "transB" 477*89c4ff92SAndroid Build Coastguard Worker i: )" + transB + R"( 478*89c4ff92SAndroid Build Coastguard Worker type: INT 479*89c4ff92SAndroid Build Coastguard Worker } 480*89c4ff92SAndroid Build Coastguard Worker } 481*89c4ff92SAndroid Build Coastguard Worker name: "gem-model" 482*89c4ff92SAndroid Build Coastguard Worker input { 483*89c4ff92SAndroid Build Coastguard Worker name: "A" 484*89c4ff92SAndroid Build Coastguard Worker type { 485*89c4ff92SAndroid Build Coastguard Worker tensor_type { 486*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 487*89c4ff92SAndroid Build Coastguard Worker shape { 488*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"( 489*89c4ff92SAndroid Build Coastguard Worker } 490*89c4ff92SAndroid Build Coastguard Worker } 491*89c4ff92SAndroid Build Coastguard Worker } 492*89c4ff92SAndroid Build Coastguard Worker } 493*89c4ff92SAndroid Build Coastguard Worker input { 494*89c4ff92SAndroid Build Coastguard Worker name: "B" 495*89c4ff92SAndroid Build Coastguard Worker type { 496*89c4ff92SAndroid Build Coastguard Worker tensor_type { 497*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 498*89c4ff92SAndroid Build Coastguard Worker shape { 499*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"( 500*89c4ff92SAndroid Build Coastguard Worker } 501*89c4ff92SAndroid Build Coastguard Worker } 502*89c4ff92SAndroid Build Coastguard Worker } 503*89c4ff92SAndroid Build Coastguard Worker } 504*89c4ff92SAndroid Build Coastguard Worker output { 505*89c4ff92SAndroid Build Coastguard Worker name: "Output" 506*89c4ff92SAndroid Build Coastguard Worker type { 507*89c4ff92SAndroid Build Coastguard Worker tensor_type { 508*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 509*89c4ff92SAndroid Build Coastguard Worker shape { 510*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 511*89c4ff92SAndroid Build Coastguard Worker } 512*89c4ff92SAndroid Build Coastguard Worker } 513*89c4ff92SAndroid Build Coastguard Worker } 514*89c4ff92SAndroid Build Coastguard Worker } 515*89c4ff92SAndroid Build Coastguard Worker })"; 516*89c4ff92SAndroid Build Coastguard Worker Setup(); 517*89c4ff92SAndroid Build Coastguard Worker } 518*89c4ff92SAndroid Build Coastguard Worker }; 519*89c4ff92SAndroid Build Coastguard Worker 520*89c4ff92SAndroid Build Coastguard Worker struct GemmAlphaTransAFixture : GemmABFixture 521*89c4ff92SAndroid Build Coastguard Worker { GemmAlphaTransAFixtureGemmAlphaTransAFixture522*89c4ff92SAndroid Build Coastguard Worker GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {} 523*89c4ff92SAndroid Build Coastguard Worker }; 524*89c4ff92SAndroid Build Coastguard Worker 525*89c4ff92SAndroid Build Coastguard Worker struct GemmAlphaTransBFixture : GemmABFixture 526*89c4ff92SAndroid Build Coastguard Worker { GemmAlphaTransBFixtureGemmAlphaTransBFixture527*89c4ff92SAndroid Build Coastguard Worker GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {} 528*89c4ff92SAndroid Build Coastguard Worker }; 529*89c4ff92SAndroid Build Coastguard Worker 530*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest") 531*89c4ff92SAndroid Build Coastguard Worker { 532*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 533*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 534*89c4ff92SAndroid Build Coastguard Worker {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 535*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 536*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 537*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}}, 538*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f, 539*89c4ff92SAndroid Build Coastguard Worker 36.5f, 43.0f, 49.5f, 56.0f, 62.5f, 540*89c4ff92SAndroid Build Coastguard Worker 28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}}); 541*89c4ff92SAndroid Build Coastguard Worker } 542*89c4ff92SAndroid Build Coastguard Worker 543*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest") 544*89c4ff92SAndroid Build Coastguard Worker { 545*89c4ff92SAndroid Build Coastguard Worker RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 546*89c4ff92SAndroid Build Coastguard Worker 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}, 547*89c4ff92SAndroid Build Coastguard Worker {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 548*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 549*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 550*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}}, 551*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f, 552*89c4ff92SAndroid Build Coastguard Worker 15.0f, 41.0f, 67.0f, 93.0f, 119.0f, 553*89c4ff92SAndroid Build Coastguard Worker 5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}}); 554*89c4ff92SAndroid Build Coastguard Worker } 555*89c4ff92SAndroid Build Coastguard Worker 556*89c4ff92SAndroid Build Coastguard Worker } 557