1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Unsqueeze") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14*89c4ff92SAndroid Build Coastguard Worker { UnsqueezeFixtureUnsqueezeFixture15*89c4ff92SAndroid Build Coastguard Worker UnsqueezeFixture(const std::vector<int>& axes, 16*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputShape, 17*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& outputShape) 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 20*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 21*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 22*89c4ff92SAndroid Build Coastguard Worker graph { 23*89c4ff92SAndroid Build Coastguard Worker node { 24*89c4ff92SAndroid Build Coastguard Worker input: "Input" 25*89c4ff92SAndroid Build Coastguard Worker output: "Output" 26*89c4ff92SAndroid Build Coastguard Worker op_type: "Unsqueeze" 27*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructIntsAttribute("axes", axes) + R"( 28*89c4ff92SAndroid Build Coastguard Worker } 29*89c4ff92SAndroid Build Coastguard Worker name: "test-model" 30*89c4ff92SAndroid Build Coastguard Worker input { 31*89c4ff92SAndroid Build Coastguard Worker name: "Input" 32*89c4ff92SAndroid Build Coastguard Worker type { 33*89c4ff92SAndroid Build Coastguard Worker tensor_type { 34*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 35*89c4ff92SAndroid Build Coastguard Worker shape { 36*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 37*89c4ff92SAndroid Build Coastguard Worker } 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker } 40*89c4ff92SAndroid Build Coastguard Worker } 41*89c4ff92SAndroid Build Coastguard Worker output { 42*89c4ff92SAndroid Build Coastguard Worker name: "Output" 43*89c4ff92SAndroid Build Coastguard Worker type { 44*89c4ff92SAndroid Build Coastguard Worker tensor_type { 45*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 46*89c4ff92SAndroid Build Coastguard Worker shape { 47*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"( 48*89c4ff92SAndroid Build Coastguard Worker } 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker } 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker })"; 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker }; 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeSingleAxesFixture : UnsqueezeFixture 57*89c4ff92SAndroid Build Coastguard Worker { UnsqueezeSingleAxesFixtureUnsqueezeSingleAxesFixture58*89c4ff92SAndroid Build Coastguard Worker UnsqueezeSingleAxesFixture() : UnsqueezeFixture({ 0 }, { 2, 3 }, { 1, 2, 3 }) 59*89c4ff92SAndroid Build Coastguard Worker { 60*89c4ff92SAndroid Build Coastguard Worker Setup(); 61*89c4ff92SAndroid Build Coastguard Worker } 62*89c4ff92SAndroid Build Coastguard Worker }; 63*89c4ff92SAndroid Build Coastguard Worker 64*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeMultiAxesFixture : UnsqueezeFixture 65*89c4ff92SAndroid Build Coastguard Worker { UnsqueezeMultiAxesFixtureUnsqueezeMultiAxesFixture66*89c4ff92SAndroid Build Coastguard Worker UnsqueezeMultiAxesFixture() : UnsqueezeFixture({ 1, 3 }, { 3, 2, 5 }, { 3, 1, 2, 1, 5 }) 67*89c4ff92SAndroid Build Coastguard Worker { 68*89c4ff92SAndroid Build Coastguard Worker Setup(); 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker }; 71*89c4ff92SAndroid Build Coastguard Worker 72*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture 73*89c4ff92SAndroid Build Coastguard Worker { UnsqueezeUnsortedAxesFixtureUnsqueezeUnsortedAxesFixture74*89c4ff92SAndroid Build Coastguard Worker UnsqueezeUnsortedAxesFixture() : UnsqueezeFixture({ 3, 0, 1 }, { 2, 5 }, { 1, 1, 2, 1, 5 }) 75*89c4ff92SAndroid Build Coastguard Worker { 76*89c4ff92SAndroid Build Coastguard Worker Setup(); 77*89c4ff92SAndroid Build Coastguard Worker } 78*89c4ff92SAndroid Build Coastguard Worker }; 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeScalarFixture : UnsqueezeFixture 81*89c4ff92SAndroid Build Coastguard Worker { UnsqueezeScalarFixtureUnsqueezeScalarFixture82*89c4ff92SAndroid Build Coastguard Worker UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 }) 83*89c4ff92SAndroid Build Coastguard Worker { 84*89c4ff92SAndroid Build Coastguard Worker Setup(); 85*89c4ff92SAndroid Build Coastguard Worker } 86*89c4ff92SAndroid Build Coastguard Worker }; 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest") 89*89c4ff92SAndroid Build Coastguard Worker { 90*89c4ff92SAndroid Build Coastguard Worker RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}, 91*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}}); 92*89c4ff92SAndroid Build Coastguard Worker } 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeMultiAxesFixture, "UnsqueezeMultiAxesTest") 95*89c4ff92SAndroid Build Coastguard Worker { 96*89c4ff92SAndroid Build Coastguard Worker RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 97*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 98*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 99*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 100*89c4ff92SAndroid Build Coastguard Worker 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 101*89c4ff92SAndroid Build Coastguard Worker 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}, 102*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 103*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 104*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 105*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 106*89c4ff92SAndroid Build Coastguard Worker 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 107*89c4ff92SAndroid Build Coastguard Worker 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}); 108*89c4ff92SAndroid Build Coastguard Worker } 109*89c4ff92SAndroid Build Coastguard Worker 110*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest") 111*89c4ff92SAndroid Build Coastguard Worker { 112*89c4ff92SAndroid Build Coastguard Worker RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 113*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}, 114*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 115*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}}); 116*89c4ff92SAndroid Build Coastguard Worker } 117*89c4ff92SAndroid Build Coastguard Worker 118*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest") 119*89c4ff92SAndroid Build Coastguard Worker { 120*89c4ff92SAndroid Build Coastguard Worker RunTest<1, float>({{"Input", { 1.0f }}}, 121*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f }}}); 122*89c4ff92SAndroid Build Coastguard Worker } 123*89c4ff92SAndroid Build Coastguard Worker 124*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 125*89c4ff92SAndroid Build Coastguard Worker { UnsqueezeInputAxesFixtureUnsqueezeInputAxesFixture126*89c4ff92SAndroid Build Coastguard Worker UnsqueezeInputAxesFixture() 127*89c4ff92SAndroid Build Coastguard Worker { 128*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 129*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 130*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 131*89c4ff92SAndroid Build Coastguard Worker graph { 132*89c4ff92SAndroid Build Coastguard Worker node { 133*89c4ff92SAndroid Build Coastguard Worker input: "Input" 134*89c4ff92SAndroid Build Coastguard Worker input: "Axes" 135*89c4ff92SAndroid Build Coastguard Worker output: "Output" 136*89c4ff92SAndroid Build Coastguard Worker op_type: "Unsqueeze" 137*89c4ff92SAndroid Build Coastguard Worker } 138*89c4ff92SAndroid Build Coastguard Worker initializer { 139*89c4ff92SAndroid Build Coastguard Worker dims: 2 140*89c4ff92SAndroid Build Coastguard Worker data_type: 7 141*89c4ff92SAndroid Build Coastguard Worker int64_data: 0 142*89c4ff92SAndroid Build Coastguard Worker int64_data: 3 143*89c4ff92SAndroid Build Coastguard Worker name: "Axes" 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker name: "test-model" 146*89c4ff92SAndroid Build Coastguard Worker input { 147*89c4ff92SAndroid Build Coastguard Worker name: "Input" 148*89c4ff92SAndroid Build Coastguard Worker type { 149*89c4ff92SAndroid Build Coastguard Worker tensor_type { 150*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 151*89c4ff92SAndroid Build Coastguard Worker shape { 152*89c4ff92SAndroid Build Coastguard Worker dim { 153*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 154*89c4ff92SAndroid Build Coastguard Worker } 155*89c4ff92SAndroid Build Coastguard Worker dim { 156*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 157*89c4ff92SAndroid Build Coastguard Worker } 158*89c4ff92SAndroid Build Coastguard Worker dim { 159*89c4ff92SAndroid Build Coastguard Worker dim_value: 5 160*89c4ff92SAndroid Build Coastguard Worker } 161*89c4ff92SAndroid Build Coastguard Worker } 162*89c4ff92SAndroid Build Coastguard Worker } 163*89c4ff92SAndroid Build Coastguard Worker } 164*89c4ff92SAndroid Build Coastguard Worker } 165*89c4ff92SAndroid Build Coastguard Worker output { 166*89c4ff92SAndroid Build Coastguard Worker name: "Output" 167*89c4ff92SAndroid Build Coastguard Worker type { 168*89c4ff92SAndroid Build Coastguard Worker tensor_type { 169*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 170*89c4ff92SAndroid Build Coastguard Worker shape { 171*89c4ff92SAndroid Build Coastguard Worker dim { 172*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 173*89c4ff92SAndroid Build Coastguard Worker } 174*89c4ff92SAndroid Build Coastguard Worker dim { 175*89c4ff92SAndroid Build Coastguard Worker dim_value: 3 176*89c4ff92SAndroid Build Coastguard Worker } 177*89c4ff92SAndroid Build Coastguard Worker dim { 178*89c4ff92SAndroid Build Coastguard Worker dim_value: 2 179*89c4ff92SAndroid Build Coastguard Worker } 180*89c4ff92SAndroid Build Coastguard Worker dim { 181*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 182*89c4ff92SAndroid Build Coastguard Worker } 183*89c4ff92SAndroid Build Coastguard Worker dim { 184*89c4ff92SAndroid Build Coastguard Worker dim_value: 5 185*89c4ff92SAndroid Build Coastguard Worker } 186*89c4ff92SAndroid Build Coastguard Worker } 187*89c4ff92SAndroid Build Coastguard Worker } 188*89c4ff92SAndroid Build Coastguard Worker } 189*89c4ff92SAndroid Build Coastguard Worker } 190*89c4ff92SAndroid Build Coastguard Worker })"; 191*89c4ff92SAndroid Build Coastguard Worker Setup(); 192*89c4ff92SAndroid Build Coastguard Worker } 193*89c4ff92SAndroid Build Coastguard Worker }; 194*89c4ff92SAndroid Build Coastguard Worker 195*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeInputAxesFixture, "UnsqueezeInputAxesTest") 196*89c4ff92SAndroid Build Coastguard Worker { 197*89c4ff92SAndroid Build Coastguard Worker RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 198*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 199*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 200*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 201*89c4ff92SAndroid Build Coastguard Worker 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 202*89c4ff92SAndroid Build Coastguard Worker 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}, 203*89c4ff92SAndroid Build Coastguard Worker {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 204*89c4ff92SAndroid Build Coastguard Worker 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 205*89c4ff92SAndroid Build Coastguard Worker 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 206*89c4ff92SAndroid Build Coastguard Worker 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 207*89c4ff92SAndroid Build Coastguard Worker 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 208*89c4ff92SAndroid Build Coastguard Worker 26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}}); 209*89c4ff92SAndroid Build Coastguard Worker } 210*89c4ff92SAndroid Build Coastguard Worker 211*89c4ff92SAndroid Build Coastguard Worker } 212