1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_FullyConnected") 10*89c4ff92SAndroid Build Coastguard Worker { 11*89c4ff92SAndroid Build Coastguard Worker // A MatMul in isolation, not connected to an add. Should result in a non-biased FullyConnected layer. 12*89c4ff92SAndroid Build Coastguard Worker struct MatMulFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 13*89c4ff92SAndroid Build Coastguard Worker { MatMulFixtureMatMulFixture14*89c4ff92SAndroid Build Coastguard Worker MatMulFixture() 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 17*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 18*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK " 19*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1 " 20*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk " 21*89c4ff92SAndroid Build Coastguard Worker model_version: 1 22*89c4ff92SAndroid Build Coastguard Worker graph { 23*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph " 24*89c4ff92SAndroid Build Coastguard Worker input { 25*89c4ff92SAndroid Build Coastguard Worker name: "Input" 26*89c4ff92SAndroid Build Coastguard Worker type { 27*89c4ff92SAndroid Build Coastguard Worker tensor_type { 28*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 29*89c4ff92SAndroid Build Coastguard Worker shape { 30*89c4ff92SAndroid Build Coastguard Worker dim { 31*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 32*89c4ff92SAndroid Build Coastguard Worker } 33*89c4ff92SAndroid Build Coastguard Worker dim { 34*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker } 37*89c4ff92SAndroid Build Coastguard Worker } 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker } 40*89c4ff92SAndroid Build Coastguard Worker input { 41*89c4ff92SAndroid Build Coastguard Worker name: "Const" 42*89c4ff92SAndroid Build Coastguard Worker type { 43*89c4ff92SAndroid Build Coastguard Worker tensor_type { 44*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 45*89c4ff92SAndroid Build Coastguard Worker shape { 46*89c4ff92SAndroid Build Coastguard Worker dim { 47*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 48*89c4ff92SAndroid Build Coastguard Worker } 49*89c4ff92SAndroid Build Coastguard Worker dim { 50*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker } 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker initializer { 57*89c4ff92SAndroid Build Coastguard Worker dims: 1 58*89c4ff92SAndroid Build Coastguard Worker dims: 1 59*89c4ff92SAndroid Build Coastguard Worker data_type: 1 60*89c4ff92SAndroid Build Coastguard Worker float_data: 17.0 61*89c4ff92SAndroid Build Coastguard Worker name: "Const" 62*89c4ff92SAndroid Build Coastguard Worker } 63*89c4ff92SAndroid Build Coastguard Worker node { 64*89c4ff92SAndroid Build Coastguard Worker input: "Input" 65*89c4ff92SAndroid Build Coastguard Worker input: "Const" 66*89c4ff92SAndroid Build Coastguard Worker output: "Output" 67*89c4ff92SAndroid Build Coastguard Worker name: "SimpleMatmul" 68*89c4ff92SAndroid Build Coastguard Worker op_type: "MatMul" 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker output { 71*89c4ff92SAndroid Build Coastguard Worker name: "Output" 72*89c4ff92SAndroid Build Coastguard Worker type { 73*89c4ff92SAndroid Build Coastguard Worker tensor_type { 74*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 75*89c4ff92SAndroid Build Coastguard Worker shape { 76*89c4ff92SAndroid Build Coastguard Worker dim { 77*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 78*89c4ff92SAndroid Build Coastguard Worker } 79*89c4ff92SAndroid Build Coastguard Worker dim { 80*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker } 85*89c4ff92SAndroid Build Coastguard Worker } 86*89c4ff92SAndroid Build Coastguard Worker } 87*89c4ff92SAndroid Build Coastguard Worker opset_import { 88*89c4ff92SAndroid Build Coastguard Worker version: 7 89*89c4ff92SAndroid Build Coastguard Worker })"; 90*89c4ff92SAndroid Build Coastguard Worker 91*89c4ff92SAndroid Build Coastguard Worker Setup(); 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker }; 94*89c4ff92SAndroid Build Coastguard Worker 95*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MatMulFixture, "MatMul") 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker RunTest<1>({{"Input", { 2 }}}, {{"Output", { 34 }}}); 98*89c4ff92SAndroid Build Coastguard Worker } 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker // In Onnx fully connected layers are expressed as a MatMul followed by an Add. 101*89c4ff92SAndroid Build Coastguard Worker // The OnnxParser must detect this case and convert them to a FullyConnected layer. 102*89c4ff92SAndroid Build Coastguard Worker struct FullyConnectedFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 103*89c4ff92SAndroid Build Coastguard Worker { FullyConnectedFixtureFullyConnectedFixture104*89c4ff92SAndroid Build Coastguard Worker FullyConnectedFixture() 105*89c4ff92SAndroid Build Coastguard Worker { 106*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 107*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 108*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK " 109*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1 " 110*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk " 111*89c4ff92SAndroid Build Coastguard Worker model_version: 1 112*89c4ff92SAndroid Build Coastguard Worker graph { 113*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph " 114*89c4ff92SAndroid Build Coastguard Worker input { 115*89c4ff92SAndroid Build Coastguard Worker name: "Input" 116*89c4ff92SAndroid Build Coastguard Worker type { 117*89c4ff92SAndroid Build Coastguard Worker tensor_type { 118*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 119*89c4ff92SAndroid Build Coastguard Worker shape { 120*89c4ff92SAndroid Build Coastguard Worker dim { 121*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 122*89c4ff92SAndroid Build Coastguard Worker } 123*89c4ff92SAndroid Build Coastguard Worker dim { 124*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 125*89c4ff92SAndroid Build Coastguard Worker } 126*89c4ff92SAndroid Build Coastguard Worker } 127*89c4ff92SAndroid Build Coastguard Worker } 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker } 130*89c4ff92SAndroid Build Coastguard Worker input { 131*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 132*89c4ff92SAndroid Build Coastguard Worker type { 133*89c4ff92SAndroid Build Coastguard Worker tensor_type { 134*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 135*89c4ff92SAndroid Build Coastguard Worker shape { 136*89c4ff92SAndroid Build Coastguard Worker dim { 137*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 138*89c4ff92SAndroid Build Coastguard Worker } 139*89c4ff92SAndroid Build Coastguard Worker dim { 140*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 141*89c4ff92SAndroid Build Coastguard Worker } 142*89c4ff92SAndroid Build Coastguard Worker } 143*89c4ff92SAndroid Build Coastguard Worker } 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker } 146*89c4ff92SAndroid Build Coastguard Worker initializer { 147*89c4ff92SAndroid Build Coastguard Worker dims: 1 148*89c4ff92SAndroid Build Coastguard Worker dims: 1 149*89c4ff92SAndroid Build Coastguard Worker data_type: 1 150*89c4ff92SAndroid Build Coastguard Worker float_data: 2 151*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 152*89c4ff92SAndroid Build Coastguard Worker } 153*89c4ff92SAndroid Build Coastguard Worker input { 154*89c4ff92SAndroid Build Coastguard Worker name: "Bias" 155*89c4ff92SAndroid Build Coastguard Worker type { 156*89c4ff92SAndroid Build Coastguard Worker tensor_type { 157*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 158*89c4ff92SAndroid Build Coastguard Worker shape { 159*89c4ff92SAndroid Build Coastguard Worker dim { 160*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker } 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker } 165*89c4ff92SAndroid Build Coastguard Worker } 166*89c4ff92SAndroid Build Coastguard Worker initializer { 167*89c4ff92SAndroid Build Coastguard Worker dims: 1 168*89c4ff92SAndroid Build Coastguard Worker data_type: 1 169*89c4ff92SAndroid Build Coastguard Worker float_data: 1 170*89c4ff92SAndroid Build Coastguard Worker name: "Bias" 171*89c4ff92SAndroid Build Coastguard Worker } 172*89c4ff92SAndroid Build Coastguard Worker node { 173*89c4ff92SAndroid Build Coastguard Worker input: "Input" 174*89c4ff92SAndroid Build Coastguard Worker input: "Weight" 175*89c4ff92SAndroid Build Coastguard Worker output: "AddInput" 176*89c4ff92SAndroid Build Coastguard Worker name: "FCMatmul" 177*89c4ff92SAndroid Build Coastguard Worker op_type: "MatMul" 178*89c4ff92SAndroid Build Coastguard Worker } 179*89c4ff92SAndroid Build Coastguard Worker node { 180*89c4ff92SAndroid Build Coastguard Worker input: "AddInput" 181*89c4ff92SAndroid Build Coastguard Worker input: "Bias" 182*89c4ff92SAndroid Build Coastguard Worker output: "Output" 183*89c4ff92SAndroid Build Coastguard Worker name: "FCAdd" 184*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 185*89c4ff92SAndroid Build Coastguard Worker } 186*89c4ff92SAndroid Build Coastguard Worker value_info { 187*89c4ff92SAndroid Build Coastguard Worker name: "AddInput" 188*89c4ff92SAndroid Build Coastguard Worker type { 189*89c4ff92SAndroid Build Coastguard Worker tensor_type { 190*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 191*89c4ff92SAndroid Build Coastguard Worker shape { 192*89c4ff92SAndroid Build Coastguard Worker dim { 193*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 194*89c4ff92SAndroid Build Coastguard Worker } 195*89c4ff92SAndroid Build Coastguard Worker dim { 196*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 197*89c4ff92SAndroid Build Coastguard Worker } 198*89c4ff92SAndroid Build Coastguard Worker } 199*89c4ff92SAndroid Build Coastguard Worker } 200*89c4ff92SAndroid Build Coastguard Worker } 201*89c4ff92SAndroid Build Coastguard Worker } 202*89c4ff92SAndroid Build Coastguard Worker output { 203*89c4ff92SAndroid Build Coastguard Worker name: "Output" 204*89c4ff92SAndroid Build Coastguard Worker type { 205*89c4ff92SAndroid Build Coastguard Worker tensor_type { 206*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 207*89c4ff92SAndroid Build Coastguard Worker shape { 208*89c4ff92SAndroid Build Coastguard Worker dim { 209*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 210*89c4ff92SAndroid Build Coastguard Worker } 211*89c4ff92SAndroid Build Coastguard Worker dim { 212*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 213*89c4ff92SAndroid Build Coastguard Worker } 214*89c4ff92SAndroid Build Coastguard Worker } 215*89c4ff92SAndroid Build Coastguard Worker } 216*89c4ff92SAndroid Build Coastguard Worker } 217*89c4ff92SAndroid Build Coastguard Worker } 218*89c4ff92SAndroid Build Coastguard Worker } 219*89c4ff92SAndroid Build Coastguard Worker opset_import { 220*89c4ff92SAndroid Build Coastguard Worker version: 7 221*89c4ff92SAndroid Build Coastguard Worker })"; 222*89c4ff92SAndroid Build Coastguard Worker 223*89c4ff92SAndroid Build Coastguard Worker Setup(); 224*89c4ff92SAndroid Build Coastguard Worker } 225*89c4ff92SAndroid Build Coastguard Worker }; 226*89c4ff92SAndroid Build Coastguard Worker 227*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(FullyConnectedFixture, "FullyConnected") 228*89c4ff92SAndroid Build Coastguard Worker { 229*89c4ff92SAndroid Build Coastguard Worker RunTest<1>({{"Input", { 3 }}}, {{"Output", { 7 }}}); 230*89c4ff92SAndroid Build Coastguard Worker } 231*89c4ff92SAndroid Build Coastguard Worker 232*89c4ff92SAndroid Build Coastguard Worker 233*89c4ff92SAndroid Build Coastguard Worker // Similar to FullyConnectedFixture, but this time the MatMul's output is used by two Adds. This should result 234*89c4ff92SAndroid Build Coastguard Worker // in two FullyConnected layers being created. 235*89c4ff92SAndroid Build Coastguard Worker // I 236*89c4ff92SAndroid Build Coastguard Worker // | 237*89c4ff92SAndroid Build Coastguard Worker // M -- C 238*89c4ff92SAndroid Build Coastguard Worker // / \' 239*89c4ff92SAndroid Build Coastguard Worker // C-- A A -- C 240*89c4ff92SAndroid Build Coastguard Worker // \ / 241*89c4ff92SAndroid Build Coastguard Worker // A 242*89c4ff92SAndroid Build Coastguard Worker struct MatMulUsedInTwoFcFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 243*89c4ff92SAndroid Build Coastguard Worker { MatMulUsedInTwoFcFixtureMatMulUsedInTwoFcFixture244*89c4ff92SAndroid Build Coastguard Worker MatMulUsedInTwoFcFixture() 245*89c4ff92SAndroid Build Coastguard Worker { 246*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 247*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 248*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK " 249*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1 " 250*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk " 251*89c4ff92SAndroid Build Coastguard Worker model_version: 1 252*89c4ff92SAndroid Build Coastguard Worker graph { 253*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph " 254*89c4ff92SAndroid Build Coastguard Worker input { 255*89c4ff92SAndroid Build Coastguard Worker name: "Input" 256*89c4ff92SAndroid Build Coastguard Worker type { 257*89c4ff92SAndroid Build Coastguard Worker tensor_type { 258*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 259*89c4ff92SAndroid Build Coastguard Worker shape { 260*89c4ff92SAndroid Build Coastguard Worker dim { 261*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 262*89c4ff92SAndroid Build Coastguard Worker } 263*89c4ff92SAndroid Build Coastguard Worker dim { 264*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 265*89c4ff92SAndroid Build Coastguard Worker } 266*89c4ff92SAndroid Build Coastguard Worker } 267*89c4ff92SAndroid Build Coastguard Worker } 268*89c4ff92SAndroid Build Coastguard Worker } 269*89c4ff92SAndroid Build Coastguard Worker } 270*89c4ff92SAndroid Build Coastguard Worker input { 271*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 272*89c4ff92SAndroid Build Coastguard Worker type { 273*89c4ff92SAndroid Build Coastguard Worker tensor_type { 274*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 275*89c4ff92SAndroid Build Coastguard Worker shape { 276*89c4ff92SAndroid Build Coastguard Worker dim { 277*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 278*89c4ff92SAndroid Build Coastguard Worker } 279*89c4ff92SAndroid Build Coastguard Worker dim { 280*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 281*89c4ff92SAndroid Build Coastguard Worker } 282*89c4ff92SAndroid Build Coastguard Worker } 283*89c4ff92SAndroid Build Coastguard Worker } 284*89c4ff92SAndroid Build Coastguard Worker } 285*89c4ff92SAndroid Build Coastguard Worker } 286*89c4ff92SAndroid Build Coastguard Worker initializer { 287*89c4ff92SAndroid Build Coastguard Worker dims: 1 288*89c4ff92SAndroid Build Coastguard Worker dims: 1 289*89c4ff92SAndroid Build Coastguard Worker data_type: 1 290*89c4ff92SAndroid Build Coastguard Worker float_data: 2 291*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 292*89c4ff92SAndroid Build Coastguard Worker } 293*89c4ff92SAndroid Build Coastguard Worker input { 294*89c4ff92SAndroid Build Coastguard Worker name: "Bias" 295*89c4ff92SAndroid Build Coastguard Worker type { 296*89c4ff92SAndroid Build Coastguard Worker tensor_type { 297*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 298*89c4ff92SAndroid Build Coastguard Worker shape { 299*89c4ff92SAndroid Build Coastguard Worker dim { 300*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 301*89c4ff92SAndroid Build Coastguard Worker } 302*89c4ff92SAndroid Build Coastguard Worker } 303*89c4ff92SAndroid Build Coastguard Worker } 304*89c4ff92SAndroid Build Coastguard Worker } 305*89c4ff92SAndroid Build Coastguard Worker } 306*89c4ff92SAndroid Build Coastguard Worker initializer { 307*89c4ff92SAndroid Build Coastguard Worker dims: 1 308*89c4ff92SAndroid Build Coastguard Worker data_type: 1 309*89c4ff92SAndroid Build Coastguard Worker float_data: 1 310*89c4ff92SAndroid Build Coastguard Worker name: "Bias" 311*89c4ff92SAndroid Build Coastguard Worker } 312*89c4ff92SAndroid Build Coastguard Worker input { 313*89c4ff92SAndroid Build Coastguard Worker name: "Bias_1" 314*89c4ff92SAndroid Build Coastguard Worker type { 315*89c4ff92SAndroid Build Coastguard Worker tensor_type { 316*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 317*89c4ff92SAndroid Build Coastguard Worker shape { 318*89c4ff92SAndroid Build Coastguard Worker dim { 319*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 320*89c4ff92SAndroid Build Coastguard Worker } 321*89c4ff92SAndroid Build Coastguard Worker } 322*89c4ff92SAndroid Build Coastguard Worker } 323*89c4ff92SAndroid Build Coastguard Worker } 324*89c4ff92SAndroid Build Coastguard Worker } 325*89c4ff92SAndroid Build Coastguard Worker initializer { 326*89c4ff92SAndroid Build Coastguard Worker dims: 1 327*89c4ff92SAndroid Build Coastguard Worker data_type: 1 328*89c4ff92SAndroid Build Coastguard Worker float_data: 10.0 329*89c4ff92SAndroid Build Coastguard Worker name: "Bias_1" 330*89c4ff92SAndroid Build Coastguard Worker } 331*89c4ff92SAndroid Build Coastguard Worker node { 332*89c4ff92SAndroid Build Coastguard Worker input: "Input" 333*89c4ff92SAndroid Build Coastguard Worker input: "Weight" 334*89c4ff92SAndroid Build Coastguard Worker output: "AddInput" 335*89c4ff92SAndroid Build Coastguard Worker name: "FCMatmul" 336*89c4ff92SAndroid Build Coastguard Worker op_type: "MatMul" 337*89c4ff92SAndroid Build Coastguard Worker } 338*89c4ff92SAndroid Build Coastguard Worker node { 339*89c4ff92SAndroid Build Coastguard Worker input: "AddInput" 340*89c4ff92SAndroid Build Coastguard Worker input: "Bias" 341*89c4ff92SAndroid Build Coastguard Worker output: "AddOutput" 342*89c4ff92SAndroid Build Coastguard Worker name: "FCAdd" 343*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 344*89c4ff92SAndroid Build Coastguard Worker } 345*89c4ff92SAndroid Build Coastguard Worker node { 346*89c4ff92SAndroid Build Coastguard Worker input: "AddInput" 347*89c4ff92SAndroid Build Coastguard Worker input: "Bias_1" 348*89c4ff92SAndroid Build Coastguard Worker output: "AddOutput_1" 349*89c4ff92SAndroid Build Coastguard Worker name: "FCAdd_1" 350*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 351*89c4ff92SAndroid Build Coastguard Worker } 352*89c4ff92SAndroid Build Coastguard Worker node { 353*89c4ff92SAndroid Build Coastguard Worker input: "AddOutput" 354*89c4ff92SAndroid Build Coastguard Worker input: "AddOutput_1" 355*89c4ff92SAndroid Build Coastguard Worker output: "Output" 356*89c4ff92SAndroid Build Coastguard Worker name: "FinalAdd" 357*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 358*89c4ff92SAndroid Build Coastguard Worker } 359*89c4ff92SAndroid Build Coastguard Worker value_info { 360*89c4ff92SAndroid Build Coastguard Worker name: "AddInput" 361*89c4ff92SAndroid Build Coastguard Worker type { 362*89c4ff92SAndroid Build Coastguard Worker tensor_type { 363*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 364*89c4ff92SAndroid Build Coastguard Worker shape { 365*89c4ff92SAndroid Build Coastguard Worker dim { 366*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 367*89c4ff92SAndroid Build Coastguard Worker } 368*89c4ff92SAndroid Build Coastguard Worker dim { 369*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 370*89c4ff92SAndroid Build Coastguard Worker } 371*89c4ff92SAndroid Build Coastguard Worker } 372*89c4ff92SAndroid Build Coastguard Worker } 373*89c4ff92SAndroid Build Coastguard Worker } 374*89c4ff92SAndroid Build Coastguard Worker } 375*89c4ff92SAndroid Build Coastguard Worker value_info { 376*89c4ff92SAndroid Build Coastguard Worker name: "AddOutput" 377*89c4ff92SAndroid Build Coastguard Worker type { 378*89c4ff92SAndroid Build Coastguard Worker tensor_type { 379*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 380*89c4ff92SAndroid Build Coastguard Worker shape { 381*89c4ff92SAndroid Build Coastguard Worker dim { 382*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 383*89c4ff92SAndroid Build Coastguard Worker } 384*89c4ff92SAndroid Build Coastguard Worker dim { 385*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 386*89c4ff92SAndroid Build Coastguard Worker } 387*89c4ff92SAndroid Build Coastguard Worker } 388*89c4ff92SAndroid Build Coastguard Worker } 389*89c4ff92SAndroid Build Coastguard Worker } 390*89c4ff92SAndroid Build Coastguard Worker } 391*89c4ff92SAndroid Build Coastguard Worker value_info { 392*89c4ff92SAndroid Build Coastguard Worker name: "AddOutput_1" 393*89c4ff92SAndroid Build Coastguard Worker type { 394*89c4ff92SAndroid Build Coastguard Worker tensor_type { 395*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 396*89c4ff92SAndroid Build Coastguard Worker shape { 397*89c4ff92SAndroid Build Coastguard Worker dim { 398*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 399*89c4ff92SAndroid Build Coastguard Worker } 400*89c4ff92SAndroid Build Coastguard Worker dim { 401*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 402*89c4ff92SAndroid Build Coastguard Worker } 403*89c4ff92SAndroid Build Coastguard Worker } 404*89c4ff92SAndroid Build Coastguard Worker } 405*89c4ff92SAndroid Build Coastguard Worker } 406*89c4ff92SAndroid Build Coastguard Worker } 407*89c4ff92SAndroid Build Coastguard Worker output { 408*89c4ff92SAndroid Build Coastguard Worker name: "Output" 409*89c4ff92SAndroid Build Coastguard Worker type { 410*89c4ff92SAndroid Build Coastguard Worker tensor_type { 411*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 412*89c4ff92SAndroid Build Coastguard Worker shape { 413*89c4ff92SAndroid Build Coastguard Worker dim { 414*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 415*89c4ff92SAndroid Build Coastguard Worker } 416*89c4ff92SAndroid Build Coastguard Worker dim { 417*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 418*89c4ff92SAndroid Build Coastguard Worker } 419*89c4ff92SAndroid Build Coastguard Worker } 420*89c4ff92SAndroid Build Coastguard Worker } 421*89c4ff92SAndroid Build Coastguard Worker } 422*89c4ff92SAndroid Build Coastguard Worker } 423*89c4ff92SAndroid Build Coastguard Worker } 424*89c4ff92SAndroid Build Coastguard Worker opset_import { 425*89c4ff92SAndroid Build Coastguard Worker version: 7 426*89c4ff92SAndroid Build Coastguard Worker })"; 427*89c4ff92SAndroid Build Coastguard Worker 428*89c4ff92SAndroid Build Coastguard Worker Setup(); 429*89c4ff92SAndroid Build Coastguard Worker } 430*89c4ff92SAndroid Build Coastguard Worker }; 431*89c4ff92SAndroid Build Coastguard Worker 432*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MatMulUsedInTwoFcFixture, "MatMulUsedInTwoFc") 433*89c4ff92SAndroid Build Coastguard Worker { 434*89c4ff92SAndroid Build Coastguard Worker RunTest<1>({{"Input", { 3 }}}, {{"Output", { 23 }}}); 435*89c4ff92SAndroid Build Coastguard Worker } 436*89c4ff92SAndroid Build Coastguard Worker 437*89c4ff92SAndroid Build Coastguard Worker 438*89c4ff92SAndroid Build Coastguard Worker // Similar to MatMulUsedInTwoFc, but this time the Adds are 'staggered' (see diagram), which means that only one 439*89c4ff92SAndroid Build Coastguard Worker // FullyConnected layer can be created (the other should just be an Add). 440*89c4ff92SAndroid Build Coastguard Worker // I 441*89c4ff92SAndroid Build Coastguard Worker // | 442*89c4ff92SAndroid Build Coastguard Worker // M -- C1 443*89c4ff92SAndroid Build Coastguard Worker // / \' 444*89c4ff92SAndroid Build Coastguard Worker // C2 -- A | 445*89c4ff92SAndroid Build Coastguard Worker // \ / 446*89c4ff92SAndroid Build Coastguard Worker // A 447*89c4ff92SAndroid Build Coastguard Worker struct MatMulUsedInTwoFcStaggeredFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 448*89c4ff92SAndroid Build Coastguard Worker { MatMulUsedInTwoFcStaggeredFixtureMatMulUsedInTwoFcStaggeredFixture449*89c4ff92SAndroid Build Coastguard Worker MatMulUsedInTwoFcStaggeredFixture() 450*89c4ff92SAndroid Build Coastguard Worker { 451*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 452*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 453*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK " 454*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1 " 455*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk " 456*89c4ff92SAndroid Build Coastguard Worker model_version: 1 457*89c4ff92SAndroid Build Coastguard Worker graph { 458*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph " 459*89c4ff92SAndroid Build Coastguard Worker input { 460*89c4ff92SAndroid Build Coastguard Worker name: "Input" 461*89c4ff92SAndroid Build Coastguard Worker type { 462*89c4ff92SAndroid Build Coastguard Worker tensor_type { 463*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 464*89c4ff92SAndroid Build Coastguard Worker shape { 465*89c4ff92SAndroid Build Coastguard Worker dim { 466*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 467*89c4ff92SAndroid Build Coastguard Worker } 468*89c4ff92SAndroid Build Coastguard Worker dim { 469*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 470*89c4ff92SAndroid Build Coastguard Worker } 471*89c4ff92SAndroid Build Coastguard Worker } 472*89c4ff92SAndroid Build Coastguard Worker } 473*89c4ff92SAndroid Build Coastguard Worker } 474*89c4ff92SAndroid Build Coastguard Worker } 475*89c4ff92SAndroid Build Coastguard Worker input { 476*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 477*89c4ff92SAndroid Build Coastguard Worker type { 478*89c4ff92SAndroid Build Coastguard Worker tensor_type { 479*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 480*89c4ff92SAndroid Build Coastguard Worker shape { 481*89c4ff92SAndroid Build Coastguard Worker dim { 482*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 483*89c4ff92SAndroid Build Coastguard Worker } 484*89c4ff92SAndroid Build Coastguard Worker dim { 485*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 486*89c4ff92SAndroid Build Coastguard Worker } 487*89c4ff92SAndroid Build Coastguard Worker } 488*89c4ff92SAndroid Build Coastguard Worker } 489*89c4ff92SAndroid Build Coastguard Worker } 490*89c4ff92SAndroid Build Coastguard Worker } 491*89c4ff92SAndroid Build Coastguard Worker initializer { 492*89c4ff92SAndroid Build Coastguard Worker dims: 1 493*89c4ff92SAndroid Build Coastguard Worker dims: 1 494*89c4ff92SAndroid Build Coastguard Worker data_type: 1 495*89c4ff92SAndroid Build Coastguard Worker float_data: 2 496*89c4ff92SAndroid Build Coastguard Worker name: "Weight" 497*89c4ff92SAndroid Build Coastguard Worker } 498*89c4ff92SAndroid Build Coastguard Worker input { 499*89c4ff92SAndroid Build Coastguard Worker name: "Bias" 500*89c4ff92SAndroid Build Coastguard Worker type { 501*89c4ff92SAndroid Build Coastguard Worker tensor_type { 502*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 503*89c4ff92SAndroid Build Coastguard Worker shape { 504*89c4ff92SAndroid Build Coastguard Worker dim { 505*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 506*89c4ff92SAndroid Build Coastguard Worker } 507*89c4ff92SAndroid Build Coastguard Worker } 508*89c4ff92SAndroid Build Coastguard Worker } 509*89c4ff92SAndroid Build Coastguard Worker } 510*89c4ff92SAndroid Build Coastguard Worker } 511*89c4ff92SAndroid Build Coastguard Worker initializer { 512*89c4ff92SAndroid Build Coastguard Worker dims: 1 513*89c4ff92SAndroid Build Coastguard Worker data_type: 1 514*89c4ff92SAndroid Build Coastguard Worker float_data: 1 515*89c4ff92SAndroid Build Coastguard Worker name: "Bias" 516*89c4ff92SAndroid Build Coastguard Worker } 517*89c4ff92SAndroid Build Coastguard Worker node { 518*89c4ff92SAndroid Build Coastguard Worker input: "Input" 519*89c4ff92SAndroid Build Coastguard Worker input: "Weight" 520*89c4ff92SAndroid Build Coastguard Worker output: "AddInput" 521*89c4ff92SAndroid Build Coastguard Worker name: "MatmulFC&NFC" 522*89c4ff92SAndroid Build Coastguard Worker op_type: "MatMul" 523*89c4ff92SAndroid Build Coastguard Worker } 524*89c4ff92SAndroid Build Coastguard Worker node { 525*89c4ff92SAndroid Build Coastguard Worker input: "AddInput" 526*89c4ff92SAndroid Build Coastguard Worker input: "Bias" 527*89c4ff92SAndroid Build Coastguard Worker output: "AddOutput" 528*89c4ff92SAndroid Build Coastguard Worker name: "FCAdd" 529*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 530*89c4ff92SAndroid Build Coastguard Worker } 531*89c4ff92SAndroid Build Coastguard Worker 532*89c4ff92SAndroid Build Coastguard Worker node { 533*89c4ff92SAndroid Build Coastguard Worker input: "AddInput" 534*89c4ff92SAndroid Build Coastguard Worker input: "AddOutput" 535*89c4ff92SAndroid Build Coastguard Worker output: "Output" 536*89c4ff92SAndroid Build Coastguard Worker name: "FinalAdd" 537*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 538*89c4ff92SAndroid Build Coastguard Worker } 539*89c4ff92SAndroid Build Coastguard Worker value_info { 540*89c4ff92SAndroid Build Coastguard Worker name: "AddInput" 541*89c4ff92SAndroid Build Coastguard Worker type { 542*89c4ff92SAndroid Build Coastguard Worker tensor_type { 543*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 544*89c4ff92SAndroid Build Coastguard Worker shape { 545*89c4ff92SAndroid Build Coastguard Worker dim { 546*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 547*89c4ff92SAndroid Build Coastguard Worker } 548*89c4ff92SAndroid Build Coastguard Worker dim { 549*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 550*89c4ff92SAndroid Build Coastguard Worker } 551*89c4ff92SAndroid Build Coastguard Worker } 552*89c4ff92SAndroid Build Coastguard Worker } 553*89c4ff92SAndroid Build Coastguard Worker } 554*89c4ff92SAndroid Build Coastguard Worker } 555*89c4ff92SAndroid Build Coastguard Worker value_info { 556*89c4ff92SAndroid Build Coastguard Worker name: "AddOutput" 557*89c4ff92SAndroid Build Coastguard Worker type { 558*89c4ff92SAndroid Build Coastguard Worker tensor_type { 559*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 560*89c4ff92SAndroid Build Coastguard Worker shape { 561*89c4ff92SAndroid Build Coastguard Worker dim { 562*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 563*89c4ff92SAndroid Build Coastguard Worker } 564*89c4ff92SAndroid Build Coastguard Worker dim { 565*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 566*89c4ff92SAndroid Build Coastguard Worker } 567*89c4ff92SAndroid Build Coastguard Worker } 568*89c4ff92SAndroid Build Coastguard Worker } 569*89c4ff92SAndroid Build Coastguard Worker } 570*89c4ff92SAndroid Build Coastguard Worker } 571*89c4ff92SAndroid Build Coastguard Worker output { 572*89c4ff92SAndroid Build Coastguard Worker name: "Output" 573*89c4ff92SAndroid Build Coastguard Worker type { 574*89c4ff92SAndroid Build Coastguard Worker tensor_type { 575*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 576*89c4ff92SAndroid Build Coastguard Worker shape { 577*89c4ff92SAndroid Build Coastguard Worker dim { 578*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 579*89c4ff92SAndroid Build Coastguard Worker } 580*89c4ff92SAndroid Build Coastguard Worker dim { 581*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 582*89c4ff92SAndroid Build Coastguard Worker } 583*89c4ff92SAndroid Build Coastguard Worker } 584*89c4ff92SAndroid Build Coastguard Worker } 585*89c4ff92SAndroid Build Coastguard Worker } 586*89c4ff92SAndroid Build Coastguard Worker } 587*89c4ff92SAndroid Build Coastguard Worker } 588*89c4ff92SAndroid Build Coastguard Worker opset_import { 589*89c4ff92SAndroid Build Coastguard Worker version: 7 590*89c4ff92SAndroid Build Coastguard Worker })"; 591*89c4ff92SAndroid Build Coastguard Worker Setup(); 592*89c4ff92SAndroid Build Coastguard Worker } 593*89c4ff92SAndroid Build Coastguard Worker }; 594*89c4ff92SAndroid Build Coastguard Worker 595*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MatMulUsedInTwoFcStaggeredFixture, "MatMulUsedInTwoFcStaggered") 596*89c4ff92SAndroid Build Coastguard Worker { 597*89c4ff92SAndroid Build Coastguard Worker RunTest<1>({{"Input", { 3 }}}, {{"Output", { 13 }}}); 598*89c4ff92SAndroid Build Coastguard Worker } 599*89c4ff92SAndroid Build Coastguard Worker 600*89c4ff92SAndroid Build Coastguard Worker } 601