1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_BatchNorm") 10*89c4ff92SAndroid Build Coastguard Worker { 11*89c4ff92SAndroid Build Coastguard Worker struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 12*89c4ff92SAndroid Build Coastguard Worker { BatchNormalizationMainFixtureBatchNormalizationMainFixture13*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationMainFixture() 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 16*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 17*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 18*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 19*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 20*89c4ff92SAndroid Build Coastguard Worker model_version: 1 21*89c4ff92SAndroid Build Coastguard Worker graph { 22*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 23*89c4ff92SAndroid Build Coastguard Worker input { 24*89c4ff92SAndroid Build Coastguard Worker name: "Input" 25*89c4ff92SAndroid Build Coastguard Worker type { 26*89c4ff92SAndroid Build Coastguard Worker tensor_type { 27*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 28*89c4ff92SAndroid Build Coastguard Worker shape { 29*89c4ff92SAndroid Build Coastguard Worker dim { 30*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 31*89c4ff92SAndroid Build Coastguard Worker } 32*89c4ff92SAndroid Build Coastguard Worker dim { 33*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 34*89c4ff92SAndroid Build Coastguard Worker } 35*89c4ff92SAndroid Build Coastguard Worker dim { 36*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 37*89c4ff92SAndroid Build Coastguard Worker } 38*89c4ff92SAndroid Build Coastguard Worker dim { 39*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 40*89c4ff92SAndroid Build Coastguard Worker } 41*89c4ff92SAndroid Build Coastguard Worker } 42*89c4ff92SAndroid Build Coastguard Worker } 43*89c4ff92SAndroid Build Coastguard Worker } 44*89c4ff92SAndroid Build Coastguard Worker } 45*89c4ff92SAndroid Build Coastguard Worker input { 46*89c4ff92SAndroid Build Coastguard Worker name: "mean" 47*89c4ff92SAndroid Build Coastguard Worker type { 48*89c4ff92SAndroid Build Coastguard Worker tensor_type { 49*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 50*89c4ff92SAndroid Build Coastguard Worker shape { 51*89c4ff92SAndroid Build Coastguard Worker dim { 52*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker } 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker input { 59*89c4ff92SAndroid Build Coastguard Worker name: "var" 60*89c4ff92SAndroid Build Coastguard Worker type { 61*89c4ff92SAndroid Build Coastguard Worker tensor_type { 62*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 63*89c4ff92SAndroid Build Coastguard Worker shape { 64*89c4ff92SAndroid Build Coastguard Worker dim { 65*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker } 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker input { 72*89c4ff92SAndroid Build Coastguard Worker name: "scale" 73*89c4ff92SAndroid Build Coastguard Worker type { 74*89c4ff92SAndroid Build Coastguard Worker tensor_type { 75*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 76*89c4ff92SAndroid Build Coastguard Worker shape { 77*89c4ff92SAndroid Build Coastguard Worker dim { 78*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker } 81*89c4ff92SAndroid Build Coastguard Worker } 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker input { 85*89c4ff92SAndroid Build Coastguard Worker name: "bias" 86*89c4ff92SAndroid Build Coastguard Worker type { 87*89c4ff92SAndroid Build Coastguard Worker tensor_type { 88*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 89*89c4ff92SAndroid Build Coastguard Worker shape { 90*89c4ff92SAndroid Build Coastguard Worker dim { 91*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker } 94*89c4ff92SAndroid Build Coastguard Worker } 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker } 97*89c4ff92SAndroid Build Coastguard Worker node { 98*89c4ff92SAndroid Build Coastguard Worker input: "Input" 99*89c4ff92SAndroid Build Coastguard Worker input: "scale" 100*89c4ff92SAndroid Build Coastguard Worker input: "bias" 101*89c4ff92SAndroid Build Coastguard Worker input: "mean" 102*89c4ff92SAndroid Build Coastguard Worker input: "var" 103*89c4ff92SAndroid Build Coastguard Worker output: "Output" 104*89c4ff92SAndroid Build Coastguard Worker name: "batchNorm" 105*89c4ff92SAndroid Build Coastguard Worker op_type: "BatchNormalization" 106*89c4ff92SAndroid Build Coastguard Worker attribute { 107*89c4ff92SAndroid Build Coastguard Worker name: "epsilon" 108*89c4ff92SAndroid Build Coastguard Worker f: 0.0010000000475 109*89c4ff92SAndroid Build Coastguard Worker type: 1 110*89c4ff92SAndroid Build Coastguard Worker } 111*89c4ff92SAndroid Build Coastguard Worker } 112*89c4ff92SAndroid Build Coastguard Worker initializer { 113*89c4ff92SAndroid Build Coastguard Worker dims: 1 114*89c4ff92SAndroid Build Coastguard Worker data_type: 1 115*89c4ff92SAndroid Build Coastguard Worker float_data: 5.0 116*89c4ff92SAndroid Build Coastguard Worker name: "mean" 117*89c4ff92SAndroid Build Coastguard Worker } 118*89c4ff92SAndroid Build Coastguard Worker initializer { 119*89c4ff92SAndroid Build Coastguard Worker dims: 1 120*89c4ff92SAndroid Build Coastguard Worker data_type: 1 121*89c4ff92SAndroid Build Coastguard Worker float_data: 2.0 122*89c4ff92SAndroid Build Coastguard Worker name: "var" 123*89c4ff92SAndroid Build Coastguard Worker } 124*89c4ff92SAndroid Build Coastguard Worker initializer { 125*89c4ff92SAndroid Build Coastguard Worker dims: 1 126*89c4ff92SAndroid Build Coastguard Worker data_type: 1 127*89c4ff92SAndroid Build Coastguard Worker float_data: 0.0 128*89c4ff92SAndroid Build Coastguard Worker name: "bias" 129*89c4ff92SAndroid Build Coastguard Worker } 130*89c4ff92SAndroid Build Coastguard Worker initializer { 131*89c4ff92SAndroid Build Coastguard Worker dims: 1 132*89c4ff92SAndroid Build Coastguard Worker data_type: 1 133*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 134*89c4ff92SAndroid Build Coastguard Worker name: "scale" 135*89c4ff92SAndroid Build Coastguard Worker } 136*89c4ff92SAndroid Build Coastguard Worker output { 137*89c4ff92SAndroid Build Coastguard Worker name: "Output" 138*89c4ff92SAndroid Build Coastguard Worker type { 139*89c4ff92SAndroid Build Coastguard Worker tensor_type { 140*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 141*89c4ff92SAndroid Build Coastguard Worker shape { 142*89c4ff92SAndroid Build Coastguard Worker dim { 143*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker dim { 146*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 147*89c4ff92SAndroid Build Coastguard Worker } 148*89c4ff92SAndroid Build Coastguard Worker dim { 149*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 150*89c4ff92SAndroid Build Coastguard Worker } 151*89c4ff92SAndroid Build Coastguard Worker dim { 152*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 153*89c4ff92SAndroid Build Coastguard Worker } 154*89c4ff92SAndroid Build Coastguard Worker } 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker } 157*89c4ff92SAndroid Build Coastguard Worker } 158*89c4ff92SAndroid Build Coastguard Worker } 159*89c4ff92SAndroid Build Coastguard Worker opset_import { 160*89c4ff92SAndroid Build Coastguard Worker version: 7 161*89c4ff92SAndroid Build Coastguard Worker })"; 162*89c4ff92SAndroid Build Coastguard Worker Setup(); 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker }; 165*89c4ff92SAndroid Build Coastguard Worker 166*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(BatchNormalizationMainFixture, "ValidBatchNormalizationTest") 167*89c4ff92SAndroid Build Coastguard Worker { 168*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}}, // Input data. 169*89c4ff92SAndroid Build Coastguard Worker {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f, 170*89c4ff92SAndroid Build Coastguard Worker -0.7069301f, 0.0f, 0.7069301f, 171*89c4ff92SAndroid Build Coastguard Worker 1.4138602f, 2.12079024f, 2.8277204f}}}); // Expected output data. 172*89c4ff92SAndroid Build Coastguard Worker } 173*89c4ff92SAndroid Build Coastguard Worker 174*89c4ff92SAndroid Build Coastguard Worker 175*89c4ff92SAndroid Build Coastguard Worker struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 176*89c4ff92SAndroid Build Coastguard Worker { BatchNormalizationBisFixtureBatchNormalizationBisFixture177*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationBisFixture() 178*89c4ff92SAndroid Build Coastguard Worker { 179*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 180*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 181*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 182*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 183*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 184*89c4ff92SAndroid Build Coastguard Worker model_version: 1 185*89c4ff92SAndroid Build Coastguard Worker graph { 186*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 187*89c4ff92SAndroid Build Coastguard Worker input { 188*89c4ff92SAndroid Build Coastguard Worker name: "Input" 189*89c4ff92SAndroid Build Coastguard Worker type { 190*89c4ff92SAndroid Build Coastguard Worker tensor_type { 191*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 192*89c4ff92SAndroid Build Coastguard Worker shape { 193*89c4ff92SAndroid Build Coastguard Worker dim { 194*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 195*89c4ff92SAndroid Build Coastguard Worker } 196*89c4ff92SAndroid Build Coastguard Worker dim { 197*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 198*89c4ff92SAndroid Build Coastguard Worker } 199*89c4ff92SAndroid Build Coastguard Worker dim { 200*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 201*89c4ff92SAndroid Build Coastguard Worker } 202*89c4ff92SAndroid Build Coastguard Worker dim { 203*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 204*89c4ff92SAndroid Build Coastguard Worker } 205*89c4ff92SAndroid Build Coastguard Worker } 206*89c4ff92SAndroid Build Coastguard Worker } 207*89c4ff92SAndroid Build Coastguard Worker } 208*89c4ff92SAndroid Build Coastguard Worker } 209*89c4ff92SAndroid Build Coastguard Worker input { 210*89c4ff92SAndroid Build Coastguard Worker name: "mean" 211*89c4ff92SAndroid Build Coastguard Worker type { 212*89c4ff92SAndroid Build Coastguard Worker tensor_type { 213*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 214*89c4ff92SAndroid Build Coastguard Worker shape { 215*89c4ff92SAndroid Build Coastguard Worker dim { 216*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 217*89c4ff92SAndroid Build Coastguard Worker } 218*89c4ff92SAndroid Build Coastguard Worker } 219*89c4ff92SAndroid Build Coastguard Worker } 220*89c4ff92SAndroid Build Coastguard Worker } 221*89c4ff92SAndroid Build Coastguard Worker } 222*89c4ff92SAndroid Build Coastguard Worker input { 223*89c4ff92SAndroid Build Coastguard Worker name: "var" 224*89c4ff92SAndroid Build Coastguard Worker type { 225*89c4ff92SAndroid Build Coastguard Worker tensor_type { 226*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 227*89c4ff92SAndroid Build Coastguard Worker shape { 228*89c4ff92SAndroid Build Coastguard Worker dim { 229*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 230*89c4ff92SAndroid Build Coastguard Worker } 231*89c4ff92SAndroid Build Coastguard Worker } 232*89c4ff92SAndroid Build Coastguard Worker } 233*89c4ff92SAndroid Build Coastguard Worker } 234*89c4ff92SAndroid Build Coastguard Worker } 235*89c4ff92SAndroid Build Coastguard Worker input { 236*89c4ff92SAndroid Build Coastguard Worker name: "scale" 237*89c4ff92SAndroid Build Coastguard Worker type { 238*89c4ff92SAndroid Build Coastguard Worker tensor_type { 239*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 240*89c4ff92SAndroid Build Coastguard Worker shape { 241*89c4ff92SAndroid Build Coastguard Worker dim { 242*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 243*89c4ff92SAndroid Build Coastguard Worker } 244*89c4ff92SAndroid Build Coastguard Worker } 245*89c4ff92SAndroid Build Coastguard Worker } 246*89c4ff92SAndroid Build Coastguard Worker } 247*89c4ff92SAndroid Build Coastguard Worker } 248*89c4ff92SAndroid Build Coastguard Worker input { 249*89c4ff92SAndroid Build Coastguard Worker name: "bias" 250*89c4ff92SAndroid Build Coastguard Worker type { 251*89c4ff92SAndroid Build Coastguard Worker tensor_type { 252*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 253*89c4ff92SAndroid Build Coastguard Worker shape { 254*89c4ff92SAndroid Build Coastguard Worker dim { 255*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 256*89c4ff92SAndroid Build Coastguard Worker } 257*89c4ff92SAndroid Build Coastguard Worker } 258*89c4ff92SAndroid Build Coastguard Worker } 259*89c4ff92SAndroid Build Coastguard Worker } 260*89c4ff92SAndroid Build Coastguard Worker } 261*89c4ff92SAndroid Build Coastguard Worker node { 262*89c4ff92SAndroid Build Coastguard Worker input: "Input" 263*89c4ff92SAndroid Build Coastguard Worker input: "scale" 264*89c4ff92SAndroid Build Coastguard Worker input: "bias" 265*89c4ff92SAndroid Build Coastguard Worker input: "mean" 266*89c4ff92SAndroid Build Coastguard Worker input: "var" 267*89c4ff92SAndroid Build Coastguard Worker output: "Output" 268*89c4ff92SAndroid Build Coastguard Worker name: "batchNorm" 269*89c4ff92SAndroid Build Coastguard Worker op_type: "BatchNormalization" 270*89c4ff92SAndroid Build Coastguard Worker attribute { 271*89c4ff92SAndroid Build Coastguard Worker name: "epsilon" 272*89c4ff92SAndroid Build Coastguard Worker f: 0.00001 273*89c4ff92SAndroid Build Coastguard Worker type: 1 274*89c4ff92SAndroid Build Coastguard Worker } 275*89c4ff92SAndroid Build Coastguard Worker } 276*89c4ff92SAndroid Build Coastguard Worker initializer { 277*89c4ff92SAndroid Build Coastguard Worker dims: 2 278*89c4ff92SAndroid Build Coastguard Worker data_type: 1 279*89c4ff92SAndroid Build Coastguard Worker float_data: 0.0 280*89c4ff92SAndroid Build Coastguard Worker float_data: 3.0 281*89c4ff92SAndroid Build Coastguard Worker name: "mean" 282*89c4ff92SAndroid Build Coastguard Worker } 283*89c4ff92SAndroid Build Coastguard Worker initializer { 284*89c4ff92SAndroid Build Coastguard Worker dims: 2 285*89c4ff92SAndroid Build Coastguard Worker data_type: 1 286*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 287*89c4ff92SAndroid Build Coastguard Worker float_data: 1.5 288*89c4ff92SAndroid Build Coastguard Worker name: "var" 289*89c4ff92SAndroid Build Coastguard Worker } 290*89c4ff92SAndroid Build Coastguard Worker initializer { 291*89c4ff92SAndroid Build Coastguard Worker dims: 2 292*89c4ff92SAndroid Build Coastguard Worker data_type: 1 293*89c4ff92SAndroid Build Coastguard Worker float_data: 0.0 294*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 295*89c4ff92SAndroid Build Coastguard Worker name: "bias" 296*89c4ff92SAndroid Build Coastguard Worker } 297*89c4ff92SAndroid Build Coastguard Worker initializer { 298*89c4ff92SAndroid Build Coastguard Worker dims: 2 299*89c4ff92SAndroid Build Coastguard Worker data_type: 1 300*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 301*89c4ff92SAndroid Build Coastguard Worker float_data: 1.5 302*89c4ff92SAndroid Build Coastguard Worker name: "scale" 303*89c4ff92SAndroid Build Coastguard Worker } 304*89c4ff92SAndroid Build Coastguard Worker output { 305*89c4ff92SAndroid Build Coastguard Worker name: "Output" 306*89c4ff92SAndroid Build Coastguard Worker type { 307*89c4ff92SAndroid Build Coastguard Worker tensor_type { 308*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 309*89c4ff92SAndroid Build Coastguard Worker shape { 310*89c4ff92SAndroid Build Coastguard Worker dim { 311*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 312*89c4ff92SAndroid Build Coastguard Worker } 313*89c4ff92SAndroid Build Coastguard Worker dim { 314*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 315*89c4ff92SAndroid Build Coastguard Worker } 316*89c4ff92SAndroid Build Coastguard Worker dim { 317*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 318*89c4ff92SAndroid Build Coastguard Worker } 319*89c4ff92SAndroid Build Coastguard Worker dim { 320*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 321*89c4ff92SAndroid Build Coastguard Worker } 322*89c4ff92SAndroid Build Coastguard Worker } 323*89c4ff92SAndroid Build Coastguard Worker } 324*89c4ff92SAndroid Build Coastguard Worker } 325*89c4ff92SAndroid Build Coastguard Worker } 326*89c4ff92SAndroid Build Coastguard Worker } 327*89c4ff92SAndroid Build Coastguard Worker opset_import { 328*89c4ff92SAndroid Build Coastguard Worker version: 7 329*89c4ff92SAndroid Build Coastguard Worker })"; 330*89c4ff92SAndroid Build Coastguard Worker Setup(); 331*89c4ff92SAndroid Build Coastguard Worker } 332*89c4ff92SAndroid Build Coastguard Worker }; 333*89c4ff92SAndroid Build Coastguard Worker 334*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(BatchNormalizationBisFixture, "ValidBatchNormalizationBisTest") 335*89c4ff92SAndroid Build Coastguard Worker { 336*89c4ff92SAndroid Build Coastguard Worker RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}}, // Input data. 337*89c4ff92SAndroid Build Coastguard Worker {{"Output", {-0.999995f, 0.0, 0.999995f, 338*89c4ff92SAndroid Build Coastguard Worker -0.22474074f, 1.0f, 2.2247407f}}}); // Expected output data. 339*89c4ff92SAndroid Build Coastguard Worker } 340*89c4ff92SAndroid Build Coastguard Worker 341*89c4ff92SAndroid Build Coastguard Worker } 342