1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Pooling") 10*89c4ff92SAndroid Build Coastguard Worker { 11*89c4ff92SAndroid Build Coastguard Worker struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12*89c4ff92SAndroid Build Coastguard Worker { PoolingMainFixturePoolingMainFixture13*89c4ff92SAndroid Build Coastguard Worker PoolingMainFixture(const std::string& dataType, const std::string& op) 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 16*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 17*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 18*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 19*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 20*89c4ff92SAndroid Build Coastguard Worker model_version: 1 21*89c4ff92SAndroid Build Coastguard Worker graph { 22*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 23*89c4ff92SAndroid Build Coastguard Worker input { 24*89c4ff92SAndroid Build Coastguard Worker name: "Input" 25*89c4ff92SAndroid Build Coastguard Worker type { 26*89c4ff92SAndroid Build Coastguard Worker tensor_type { 27*89c4ff92SAndroid Build Coastguard Worker elem_type: )" + dataType + R"( 28*89c4ff92SAndroid Build Coastguard Worker shape { 29*89c4ff92SAndroid Build Coastguard Worker dim { 30*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 31*89c4ff92SAndroid Build Coastguard Worker } 32*89c4ff92SAndroid Build Coastguard Worker dim { 33*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 34*89c4ff92SAndroid Build Coastguard Worker } 35*89c4ff92SAndroid Build Coastguard Worker dim { 36*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 37*89c4ff92SAndroid Build Coastguard Worker } 38*89c4ff92SAndroid Build Coastguard Worker dim { 39*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 40*89c4ff92SAndroid Build Coastguard Worker } 41*89c4ff92SAndroid Build Coastguard Worker } 42*89c4ff92SAndroid Build Coastguard Worker } 43*89c4ff92SAndroid Build Coastguard Worker } 44*89c4ff92SAndroid Build Coastguard Worker } 45*89c4ff92SAndroid Build Coastguard Worker node { 46*89c4ff92SAndroid Build Coastguard Worker input: "Input" 47*89c4ff92SAndroid Build Coastguard Worker output: "Output" 48*89c4ff92SAndroid Build Coastguard Worker name: "Pooling" 49*89c4ff92SAndroid Build Coastguard Worker op_type: )" + op + R"( 50*89c4ff92SAndroid Build Coastguard Worker attribute { 51*89c4ff92SAndroid Build Coastguard Worker name: "kernel_shape" 52*89c4ff92SAndroid Build Coastguard Worker ints: 2 53*89c4ff92SAndroid Build Coastguard Worker ints: 2 54*89c4ff92SAndroid Build Coastguard Worker type: INTS 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker attribute { 57*89c4ff92SAndroid Build Coastguard Worker name: "strides" 58*89c4ff92SAndroid Build Coastguard Worker ints: 1 59*89c4ff92SAndroid Build Coastguard Worker ints: 1 60*89c4ff92SAndroid Build Coastguard Worker type: INTS 61*89c4ff92SAndroid Build Coastguard Worker } 62*89c4ff92SAndroid Build Coastguard Worker attribute { 63*89c4ff92SAndroid Build Coastguard Worker name: "pads" 64*89c4ff92SAndroid Build Coastguard Worker ints: 0 65*89c4ff92SAndroid Build Coastguard Worker ints: 0 66*89c4ff92SAndroid Build Coastguard Worker ints: 0 67*89c4ff92SAndroid Build Coastguard Worker ints: 0 68*89c4ff92SAndroid Build Coastguard Worker type: INTS 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker output { 72*89c4ff92SAndroid Build Coastguard Worker name: "Output" 73*89c4ff92SAndroid Build Coastguard Worker type { 74*89c4ff92SAndroid Build Coastguard Worker tensor_type { 75*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 76*89c4ff92SAndroid Build Coastguard Worker shape { 77*89c4ff92SAndroid Build Coastguard Worker dim { 78*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker dim { 81*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker dim { 84*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 85*89c4ff92SAndroid Build Coastguard Worker } 86*89c4ff92SAndroid Build Coastguard Worker dim { 87*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 88*89c4ff92SAndroid Build Coastguard Worker } 89*89c4ff92SAndroid Build Coastguard Worker } 90*89c4ff92SAndroid Build Coastguard Worker } 91*89c4ff92SAndroid Build Coastguard Worker } 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker } 94*89c4ff92SAndroid Build Coastguard Worker opset_import { 95*89c4ff92SAndroid Build Coastguard Worker version: 7 96*89c4ff92SAndroid Build Coastguard Worker })"; 97*89c4ff92SAndroid Build Coastguard Worker } 98*89c4ff92SAndroid Build Coastguard Worker }; 99*89c4ff92SAndroid Build Coastguard Worker 100*89c4ff92SAndroid Build Coastguard Worker struct MaxPoolValidFixture : PoolingMainFixture 101*89c4ff92SAndroid Build Coastguard Worker { MaxPoolValidFixtureMaxPoolValidFixture102*89c4ff92SAndroid Build Coastguard Worker MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") { 103*89c4ff92SAndroid Build Coastguard Worker Setup(); 104*89c4ff92SAndroid Build Coastguard Worker } 105*89c4ff92SAndroid Build Coastguard Worker }; 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker struct MaxPoolInvalidFixture : PoolingMainFixture 108*89c4ff92SAndroid Build Coastguard Worker { MaxPoolInvalidFixtureMaxPoolInvalidFixture109*89c4ff92SAndroid Build Coastguard Worker MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { } 110*89c4ff92SAndroid Build Coastguard Worker }; 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MaxPoolValidFixture, "ValidMaxPoolTest") 113*89c4ff92SAndroid Build Coastguard Worker { 114*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}}); 115*89c4ff92SAndroid Build Coastguard Worker } 116*89c4ff92SAndroid Build Coastguard Worker 117*89c4ff92SAndroid Build Coastguard Worker struct AvgPoolValidFixture : PoolingMainFixture 118*89c4ff92SAndroid Build Coastguard Worker { AvgPoolValidFixtureAvgPoolValidFixture119*89c4ff92SAndroid Build Coastguard Worker AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") { 120*89c4ff92SAndroid Build Coastguard Worker Setup(); 121*89c4ff92SAndroid Build Coastguard Worker } 122*89c4ff92SAndroid Build Coastguard Worker }; 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 125*89c4ff92SAndroid Build Coastguard Worker { PoolingWithPadFixturePoolingWithPadFixture126*89c4ff92SAndroid Build Coastguard Worker PoolingWithPadFixture() 127*89c4ff92SAndroid Build Coastguard Worker { 128*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 129*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 130*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 131*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 132*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 133*89c4ff92SAndroid Build Coastguard Worker model_version: 1 134*89c4ff92SAndroid Build Coastguard Worker graph { 135*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 136*89c4ff92SAndroid Build Coastguard Worker input { 137*89c4ff92SAndroid Build Coastguard Worker name: "Input" 138*89c4ff92SAndroid Build Coastguard Worker type { 139*89c4ff92SAndroid Build Coastguard Worker tensor_type { 140*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 141*89c4ff92SAndroid Build Coastguard Worker shape { 142*89c4ff92SAndroid Build Coastguard Worker dim { 143*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker dim { 146*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 147*89c4ff92SAndroid Build Coastguard Worker } 148*89c4ff92SAndroid Build Coastguard Worker dim { 149*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 150*89c4ff92SAndroid Build Coastguard Worker } 151*89c4ff92SAndroid Build Coastguard Worker dim { 152*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 153*89c4ff92SAndroid Build Coastguard Worker } 154*89c4ff92SAndroid Build Coastguard Worker } 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker } 157*89c4ff92SAndroid Build Coastguard Worker } 158*89c4ff92SAndroid Build Coastguard Worker node { 159*89c4ff92SAndroid Build Coastguard Worker input: "Input" 160*89c4ff92SAndroid Build Coastguard Worker output: "Output" 161*89c4ff92SAndroid Build Coastguard Worker name: "Pooling" 162*89c4ff92SAndroid Build Coastguard Worker op_type: "AveragePool" 163*89c4ff92SAndroid Build Coastguard Worker attribute { 164*89c4ff92SAndroid Build Coastguard Worker name: "kernel_shape" 165*89c4ff92SAndroid Build Coastguard Worker ints: 4 166*89c4ff92SAndroid Build Coastguard Worker ints: 4 167*89c4ff92SAndroid Build Coastguard Worker type: INTS 168*89c4ff92SAndroid Build Coastguard Worker } 169*89c4ff92SAndroid Build Coastguard Worker attribute { 170*89c4ff92SAndroid Build Coastguard Worker name: "strides" 171*89c4ff92SAndroid Build Coastguard Worker ints: 1 172*89c4ff92SAndroid Build Coastguard Worker ints: 1 173*89c4ff92SAndroid Build Coastguard Worker type: INTS 174*89c4ff92SAndroid Build Coastguard Worker } 175*89c4ff92SAndroid Build Coastguard Worker attribute { 176*89c4ff92SAndroid Build Coastguard Worker name: "pads" 177*89c4ff92SAndroid Build Coastguard Worker ints: 1 178*89c4ff92SAndroid Build Coastguard Worker ints: 1 179*89c4ff92SAndroid Build Coastguard Worker ints: 1 180*89c4ff92SAndroid Build Coastguard Worker ints: 1 181*89c4ff92SAndroid Build Coastguard Worker type: INTS 182*89c4ff92SAndroid Build Coastguard Worker } 183*89c4ff92SAndroid Build Coastguard Worker attribute { 184*89c4ff92SAndroid Build Coastguard Worker name: "count_include_pad" 185*89c4ff92SAndroid Build Coastguard Worker i: 1 186*89c4ff92SAndroid Build Coastguard Worker type: INT 187*89c4ff92SAndroid Build Coastguard Worker } 188*89c4ff92SAndroid Build Coastguard Worker } 189*89c4ff92SAndroid Build Coastguard Worker output { 190*89c4ff92SAndroid Build Coastguard Worker name: "Output" 191*89c4ff92SAndroid Build Coastguard Worker type { 192*89c4ff92SAndroid Build Coastguard Worker tensor_type { 193*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 194*89c4ff92SAndroid Build Coastguard Worker shape { 195*89c4ff92SAndroid Build Coastguard Worker dim { 196*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 197*89c4ff92SAndroid Build Coastguard Worker } 198*89c4ff92SAndroid Build Coastguard Worker dim { 199*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 200*89c4ff92SAndroid Build Coastguard Worker } 201*89c4ff92SAndroid Build Coastguard Worker dim { 202*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 203*89c4ff92SAndroid Build Coastguard Worker } 204*89c4ff92SAndroid Build Coastguard Worker dim { 205*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 206*89c4ff92SAndroid Build Coastguard Worker } 207*89c4ff92SAndroid Build Coastguard Worker } 208*89c4ff92SAndroid Build Coastguard Worker } 209*89c4ff92SAndroid Build Coastguard Worker } 210*89c4ff92SAndroid Build Coastguard Worker } 211*89c4ff92SAndroid Build Coastguard Worker } 212*89c4ff92SAndroid Build Coastguard Worker opset_import { 213*89c4ff92SAndroid Build Coastguard Worker version: 7 214*89c4ff92SAndroid Build Coastguard Worker })"; 215*89c4ff92SAndroid Build Coastguard Worker Setup(); 216*89c4ff92SAndroid Build Coastguard Worker } 217*89c4ff92SAndroid Build Coastguard Worker }; 218*89c4ff92SAndroid Build Coastguard Worker 219*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(AvgPoolValidFixture, "AveragePoolValid") 220*89c4ff92SAndroid Build Coastguard Worker { 221*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}}); 222*89c4ff92SAndroid Build Coastguard Worker } 223*89c4ff92SAndroid Build Coastguard Worker 224*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(PoolingWithPadFixture, "ValidAvgWithPadTest") 225*89c4ff92SAndroid Build Coastguard Worker { 226*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}}); 227*89c4ff92SAndroid Build Coastguard Worker } 228*89c4ff92SAndroid Build Coastguard Worker 229*89c4ff92SAndroid Build Coastguard Worker struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 230*89c4ff92SAndroid Build Coastguard Worker { GlobalAvgFixtureGlobalAvgFixture231*89c4ff92SAndroid Build Coastguard Worker GlobalAvgFixture() 232*89c4ff92SAndroid Build Coastguard Worker { 233*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 234*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 235*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 236*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 237*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 238*89c4ff92SAndroid Build Coastguard Worker model_version: 1 239*89c4ff92SAndroid Build Coastguard Worker graph { 240*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 241*89c4ff92SAndroid Build Coastguard Worker input { 242*89c4ff92SAndroid Build Coastguard Worker name: "Input" 243*89c4ff92SAndroid Build Coastguard Worker type { 244*89c4ff92SAndroid Build Coastguard Worker tensor_type { 245*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 246*89c4ff92SAndroid Build Coastguard Worker shape { 247*89c4ff92SAndroid Build Coastguard Worker dim { 248*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 249*89c4ff92SAndroid Build Coastguard Worker } 250*89c4ff92SAndroid Build Coastguard Worker dim { 251*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 252*89c4ff92SAndroid Build Coastguard Worker } 253*89c4ff92SAndroid Build Coastguard Worker dim { 254*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 255*89c4ff92SAndroid Build Coastguard Worker } 256*89c4ff92SAndroid Build Coastguard Worker dim { 257*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 258*89c4ff92SAndroid Build Coastguard Worker } 259*89c4ff92SAndroid Build Coastguard Worker } 260*89c4ff92SAndroid Build Coastguard Worker } 261*89c4ff92SAndroid Build Coastguard Worker } 262*89c4ff92SAndroid Build Coastguard Worker } 263*89c4ff92SAndroid Build Coastguard Worker node { 264*89c4ff92SAndroid Build Coastguard Worker input: "Input" 265*89c4ff92SAndroid Build Coastguard Worker output: "Output" 266*89c4ff92SAndroid Build Coastguard Worker name: "Pooling" 267*89c4ff92SAndroid Build Coastguard Worker op_type: "GlobalAveragePool" 268*89c4ff92SAndroid Build Coastguard Worker } 269*89c4ff92SAndroid Build Coastguard Worker output { 270*89c4ff92SAndroid Build Coastguard Worker name: "Output" 271*89c4ff92SAndroid Build Coastguard Worker type { 272*89c4ff92SAndroid Build Coastguard Worker tensor_type { 273*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 274*89c4ff92SAndroid Build Coastguard Worker shape { 275*89c4ff92SAndroid Build Coastguard Worker dim { 276*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 277*89c4ff92SAndroid Build Coastguard Worker } 278*89c4ff92SAndroid Build Coastguard Worker dim { 279*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 280*89c4ff92SAndroid Build Coastguard Worker } 281*89c4ff92SAndroid Build Coastguard Worker dim { 282*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 283*89c4ff92SAndroid Build Coastguard Worker } 284*89c4ff92SAndroid Build Coastguard Worker dim { 285*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 286*89c4ff92SAndroid Build Coastguard Worker } 287*89c4ff92SAndroid Build Coastguard Worker } 288*89c4ff92SAndroid Build Coastguard Worker } 289*89c4ff92SAndroid Build Coastguard Worker } 290*89c4ff92SAndroid Build Coastguard Worker } 291*89c4ff92SAndroid Build Coastguard Worker } 292*89c4ff92SAndroid Build Coastguard Worker opset_import { 293*89c4ff92SAndroid Build Coastguard Worker version: 7 294*89c4ff92SAndroid Build Coastguard Worker })"; 295*89c4ff92SAndroid Build Coastguard Worker Setup(); 296*89c4ff92SAndroid Build Coastguard Worker } 297*89c4ff92SAndroid Build Coastguard Worker }; 298*89c4ff92SAndroid Build Coastguard Worker 299*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GlobalAvgFixture, "GlobalAvgTest") 300*89c4ff92SAndroid Build Coastguard Worker { 301*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}}); 302*89c4ff92SAndroid Build Coastguard Worker } 303*89c4ff92SAndroid Build Coastguard Worker 304*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MaxPoolInvalidFixture, "IncorrectDataTypeMaxPool") 305*89c4ff92SAndroid Build Coastguard Worker { 306*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(Setup(), armnn::ParseException); 307*89c4ff92SAndroid Build Coastguard Worker } 308*89c4ff92SAndroid Build Coastguard Worker 309*89c4ff92SAndroid Build Coastguard Worker } 310