xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Shape.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Shape")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker struct ShapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14*89c4ff92SAndroid Build Coastguard Worker {
ShapeMainFixtureShapeMainFixture15*89c4ff92SAndroid Build Coastguard Worker     ShapeMainFixture(const std::string& inputType,
16*89c4ff92SAndroid Build Coastguard Worker                      const std::string& outputType,
17*89c4ff92SAndroid Build Coastguard Worker                      const std::string& outputDim,
18*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<int>& inputShape)
19*89c4ff92SAndroid Build Coastguard Worker     {
20*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
21*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
22*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
23*89c4ff92SAndroid Build Coastguard Worker                     graph {
24*89c4ff92SAndroid Build Coastguard Worker                       node {
25*89c4ff92SAndroid Build Coastguard Worker                         input: "Input"
26*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
27*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Shape"
28*89c4ff92SAndroid Build Coastguard Worker                       }
29*89c4ff92SAndroid Build Coastguard Worker                       name: "shape-model"
30*89c4ff92SAndroid Build Coastguard Worker                       input {
31*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
32*89c4ff92SAndroid Build Coastguard Worker                         type {
33*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
34*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + inputType + R"(
35*89c4ff92SAndroid Build Coastguard Worker                             shape {
36*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
37*89c4ff92SAndroid Build Coastguard Worker                             }
38*89c4ff92SAndroid Build Coastguard Worker                           }
39*89c4ff92SAndroid Build Coastguard Worker                         }
40*89c4ff92SAndroid Build Coastguard Worker                       }
41*89c4ff92SAndroid Build Coastguard Worker                       output {
42*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
43*89c4ff92SAndroid Build Coastguard Worker                         type {
44*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
45*89c4ff92SAndroid Build Coastguard Worker                             elem_type: )" + outputType + R"(
46*89c4ff92SAndroid Build Coastguard Worker                             shape {
47*89c4ff92SAndroid Build Coastguard Worker                               dim {
48*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: )" + outputDim + R"(
49*89c4ff92SAndroid Build Coastguard Worker                               }
50*89c4ff92SAndroid Build Coastguard Worker                             }
51*89c4ff92SAndroid Build Coastguard Worker                           }
52*89c4ff92SAndroid Build Coastguard Worker                         }
53*89c4ff92SAndroid Build Coastguard Worker                       }
54*89c4ff92SAndroid Build Coastguard Worker                     }
55*89c4ff92SAndroid Build Coastguard Worker                     opset_import {
56*89c4ff92SAndroid Build Coastguard Worker                       version: 10
57*89c4ff92SAndroid Build Coastguard Worker                     })";
58*89c4ff92SAndroid Build Coastguard Worker     }
59*89c4ff92SAndroid Build Coastguard Worker };
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker struct ShapeFloatFixture : ShapeMainFixture
62*89c4ff92SAndroid Build Coastguard Worker {
ShapeFloatFixtureShapeFloatFixture63*89c4ff92SAndroid Build Coastguard Worker     ShapeFloatFixture() : ShapeMainFixture("1", "7", "4", { 1, 3, 1, 5 })
64*89c4ff92SAndroid Build Coastguard Worker     {
65*89c4ff92SAndroid Build Coastguard Worker         Setup();
66*89c4ff92SAndroid Build Coastguard Worker     }
67*89c4ff92SAndroid Build Coastguard Worker };
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker struct ShapeIntFixture : ShapeMainFixture
70*89c4ff92SAndroid Build Coastguard Worker {
ShapeIntFixtureShapeIntFixture71*89c4ff92SAndroid Build Coastguard Worker     ShapeIntFixture() : ShapeMainFixture("7", "7", "4", { 1, 3, 1, 5 })
72*89c4ff92SAndroid Build Coastguard Worker     {
73*89c4ff92SAndroid Build Coastguard Worker         Setup();
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker };
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker struct Shape3DFixture : ShapeMainFixture
78*89c4ff92SAndroid Build Coastguard Worker {
Shape3DFixtureShape3DFixture79*89c4ff92SAndroid Build Coastguard Worker     Shape3DFixture() : ShapeMainFixture("1", "7", "3", { 3, 2, 3 })
80*89c4ff92SAndroid Build Coastguard Worker     {
81*89c4ff92SAndroid Build Coastguard Worker         Setup();
82*89c4ff92SAndroid Build Coastguard Worker     }
83*89c4ff92SAndroid Build Coastguard Worker };
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker struct Shape2DFixture : ShapeMainFixture
86*89c4ff92SAndroid Build Coastguard Worker {
Shape2DFixtureShape2DFixture87*89c4ff92SAndroid Build Coastguard Worker     Shape2DFixture() : ShapeMainFixture("1", "7", "2", { 2, 3 })
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         Setup();
90*89c4ff92SAndroid Build Coastguard Worker     }
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker struct Shape1DFixture : ShapeMainFixture
94*89c4ff92SAndroid Build Coastguard Worker {
Shape1DFixtureShape1DFixture95*89c4ff92SAndroid Build Coastguard Worker     Shape1DFixture() : ShapeMainFixture("1", "7", "1", { 5 })
96*89c4ff92SAndroid Build Coastguard Worker     {
97*89c4ff92SAndroid Build Coastguard Worker         Setup();
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker };
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ShapeFloatFixture, "FloatValidShapeTest")
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f,
104*89c4ff92SAndroid Build Coastguard Worker                                  4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
105*89c4ff92SAndroid Build Coastguard Worker                                  0.0f, 1.0f, 2.0f, 3.0f, 4.0f }}}, {{"Output", { 1, 3, 1, 5 }}});
106*89c4ff92SAndroid Build Coastguard Worker }
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ShapeIntFixture, "IntValidShapeTest")
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, int>({{"Input", { 0, 1, 2, 3, 4,
111*89c4ff92SAndroid Build Coastguard Worker                                  4, 3, 2, 1, 0,
112*89c4ff92SAndroid Build Coastguard Worker                                  0, 1, 2, 3, 4 }}}, {{"Output", { 1, 3, 1, 5 }}});
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Shape3DFixture, "Shape3DTest")
116*89c4ff92SAndroid Build Coastguard Worker {
117*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
118*89c4ff92SAndroid Build Coastguard Worker                                  5.0f, 4.0f, 3.0f, 2.0f, 1.0f, 0.0f,
119*89c4ff92SAndroid Build Coastguard Worker                                  0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 3, 2, 3 }}});
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Shape2DFixture, "Shape2DTest")
123*89c4ff92SAndroid Build Coastguard Worker {
124*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, int>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 2, 3 }}});
125*89c4ff92SAndroid Build Coastguard Worker }
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Shape1DFixture, "Shape1DTest")
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, int>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f }}}, {{"Output", { 5 }}});
130*89c4ff92SAndroid Build Coastguard Worker }
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker }
133