1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Shape") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 14*89c4ff92SAndroid Build Coastguard Worker { ShapeMainFixtureShapeMainFixture15*89c4ff92SAndroid Build Coastguard Worker ShapeMainFixture(const std::string& inputType, 16*89c4ff92SAndroid Build Coastguard Worker const std::string& outputType, 17*89c4ff92SAndroid Build Coastguard Worker const std::string& outputDim, 18*89c4ff92SAndroid Build Coastguard Worker const std::vector<int>& inputShape) 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 21*89c4ff92SAndroid Build Coastguard Worker ir_version: 8 22*89c4ff92SAndroid Build Coastguard Worker producer_name: "onnx-example" 23*89c4ff92SAndroid Build Coastguard Worker graph { 24*89c4ff92SAndroid Build Coastguard Worker node { 25*89c4ff92SAndroid Build Coastguard Worker input: "Input" 26*89c4ff92SAndroid Build Coastguard Worker output: "Output" 27*89c4ff92SAndroid Build Coastguard Worker op_type: "Shape" 28*89c4ff92SAndroid Build Coastguard Worker } 29*89c4ff92SAndroid Build Coastguard Worker name: "shape-model" 30*89c4ff92SAndroid Build Coastguard Worker input { 31*89c4ff92SAndroid Build Coastguard Worker name: "Input" 32*89c4ff92SAndroid Build Coastguard Worker type { 33*89c4ff92SAndroid Build Coastguard Worker tensor_type { 34*89c4ff92SAndroid Build Coastguard Worker elem_type: )" + inputType + R"( 35*89c4ff92SAndroid Build Coastguard Worker shape { 36*89c4ff92SAndroid Build Coastguard Worker )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"( 37*89c4ff92SAndroid Build Coastguard Worker } 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker } 40*89c4ff92SAndroid Build Coastguard Worker } 41*89c4ff92SAndroid Build Coastguard Worker output { 42*89c4ff92SAndroid Build Coastguard Worker name: "Output" 43*89c4ff92SAndroid Build Coastguard Worker type { 44*89c4ff92SAndroid Build Coastguard Worker tensor_type { 45*89c4ff92SAndroid Build Coastguard Worker elem_type: )" + outputType + R"( 46*89c4ff92SAndroid Build Coastguard Worker shape { 47*89c4ff92SAndroid Build Coastguard Worker dim { 48*89c4ff92SAndroid Build Coastguard Worker dim_value: )" + outputDim + R"( 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker } 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker } 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker opset_import { 56*89c4ff92SAndroid Build Coastguard Worker version: 10 57*89c4ff92SAndroid Build Coastguard Worker })"; 58*89c4ff92SAndroid Build Coastguard Worker } 59*89c4ff92SAndroid Build Coastguard Worker }; 60*89c4ff92SAndroid Build Coastguard Worker 61*89c4ff92SAndroid Build Coastguard Worker struct ShapeFloatFixture : ShapeMainFixture 62*89c4ff92SAndroid Build Coastguard Worker { ShapeFloatFixtureShapeFloatFixture63*89c4ff92SAndroid Build Coastguard Worker ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 }) 64*89c4ff92SAndroid Build Coastguard Worker { 65*89c4ff92SAndroid Build Coastguard Worker Setup(); 66*89c4ff92SAndroid Build Coastguard Worker } 67*89c4ff92SAndroid Build Coastguard Worker }; 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker struct ShapeIntFixture : ShapeMainFixture 70*89c4ff92SAndroid Build Coastguard Worker { ShapeIntFixtureShapeIntFixture71*89c4ff92SAndroid Build Coastguard Worker ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 }) 72*89c4ff92SAndroid Build Coastguard Worker { 73*89c4ff92SAndroid Build Coastguard Worker Setup(); 74*89c4ff92SAndroid Build Coastguard Worker } 75*89c4ff92SAndroid Build Coastguard Worker }; 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker struct Shape3DFixture : ShapeMainFixture 78*89c4ff92SAndroid Build Coastguard Worker { Shape3DFixtureShape3DFixture79*89c4ff92SAndroid Build Coastguard Worker Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 }) 80*89c4ff92SAndroid Build Coastguard Worker { 81*89c4ff92SAndroid Build Coastguard Worker Setup(); 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker }; 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker struct Shape2DFixture : ShapeMainFixture 86*89c4ff92SAndroid Build Coastguard Worker { Shape2DFixtureShape2DFixture87*89c4ff92SAndroid Build Coastguard Worker Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 }) 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker Setup(); 90*89c4ff92SAndroid Build Coastguard Worker } 91*89c4ff92SAndroid Build Coastguard Worker }; 92*89c4ff92SAndroid Build Coastguard Worker 93*89c4ff92SAndroid Build Coastguard Worker struct Shape1DFixture : ShapeMainFixture 94*89c4ff92SAndroid Build Coastguard Worker { Shape1DFixtureShape1DFixture95*89c4ff92SAndroid Build Coastguard Worker Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 }) 96*89c4ff92SAndroid Build Coastguard Worker { 97*89c4ff92SAndroid Build Coastguard Worker Setup(); 98*89c4ff92SAndroid Build Coastguard Worker } 99*89c4ff92SAndroid Build Coastguard Worker }; 100*89c4ff92SAndroid Build Coastguard Worker 101*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest") 102*89c4ff92SAndroid Build Coastguard Worker { 103*89c4ff92SAndroid Build Coastguard Worker RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 104*89c4ff92SAndroid Build Coastguard Worker 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, 105*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}}); 106*89c4ff92SAndroid Build Coastguard Worker } 107*89c4ff92SAndroid Build Coastguard Worker 108*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest") 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4, 111*89c4ff92SAndroid Build Coastguard Worker 4, 3, 2, 1, 0, 112*89c4ff92SAndroid Build Coastguard Worker 0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}}); 113*89c4ff92SAndroid Build Coastguard Worker } 114*89c4ff92SAndroid Build Coastguard Worker 115*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest") 116*89c4ff92SAndroid Build Coastguard Worker { 117*89c4ff92SAndroid Build Coastguard Worker RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 118*89c4ff92SAndroid Build Coastguard Worker 5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f, 119*89c4ff92SAndroid Build Coastguard Worker 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}}); 120*89c4ff92SAndroid Build Coastguard Worker } 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest") 123*89c4ff92SAndroid Build Coastguard Worker { 124*89c4ff92SAndroid Build Coastguard Worker RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}}); 125*89c4ff92SAndroid Build Coastguard Worker } 126*89c4ff92SAndroid Build Coastguard Worker 127*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest") 128*89c4ff92SAndroid Build Coastguard Worker { 129*89c4ff92SAndroid Build Coastguard Worker RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}}); 130*89c4ff92SAndroid Build Coastguard Worker } 131*89c4ff92SAndroid Build Coastguard Worker 132*89c4ff92SAndroid Build Coastguard Worker } 133