xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Unsqueeze.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "OnnxParserTestUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Unsqueeze")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14*89c4ff92SAndroid Build Coastguard Worker {
UnsqueezeFixtureUnsqueezeFixture15*89c4ff92SAndroid Build Coastguard Worker     UnsqueezeFixture(const std::vector<int>& axes,
16*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<int>& inputShape,
17*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<int>& outputShape)
18*89c4ff92SAndroid Build Coastguard Worker     {
19*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
20*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
21*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
22*89c4ff92SAndroid Build Coastguard Worker                     graph {
23*89c4ff92SAndroid Build Coastguard Worker                       node {
24*89c4ff92SAndroid Build Coastguard Worker                         input: "Input"
25*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
26*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Unsqueeze"
27*89c4ff92SAndroid Build Coastguard Worker                         )" + armnnUtils::ConstructIntsAttribute("axes", axes) + R"(
28*89c4ff92SAndroid Build Coastguard Worker                       }
29*89c4ff92SAndroid Build Coastguard Worker                       name: "test-model"
30*89c4ff92SAndroid Build Coastguard Worker                       input {
31*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
32*89c4ff92SAndroid Build Coastguard Worker                         type {
33*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
34*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
35*89c4ff92SAndroid Build Coastguard Worker                             shape {
36*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
37*89c4ff92SAndroid Build Coastguard Worker                             }
38*89c4ff92SAndroid Build Coastguard Worker                           }
39*89c4ff92SAndroid Build Coastguard Worker                         }
40*89c4ff92SAndroid Build Coastguard Worker                       }
41*89c4ff92SAndroid Build Coastguard Worker                       output {
42*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
43*89c4ff92SAndroid Build Coastguard Worker                         type {
44*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
45*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
46*89c4ff92SAndroid Build Coastguard Worker                             shape {
47*89c4ff92SAndroid Build Coastguard Worker                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
48*89c4ff92SAndroid Build Coastguard Worker                             }
49*89c4ff92SAndroid Build Coastguard Worker                           }
50*89c4ff92SAndroid Build Coastguard Worker                         }
51*89c4ff92SAndroid Build Coastguard Worker                       }
52*89c4ff92SAndroid Build Coastguard Worker                     })";
53*89c4ff92SAndroid Build Coastguard Worker     }
54*89c4ff92SAndroid Build Coastguard Worker };
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeSingleAxesFixture : UnsqueezeFixture
57*89c4ff92SAndroid Build Coastguard Worker {
UnsqueezeSingleAxesFixtureUnsqueezeSingleAxesFixture58*89c4ff92SAndroid Build Coastguard Worker     UnsqueezeSingleAxesFixture() : UnsqueezeFixture({ 0 }, { 2, 3 }, { 1, 2, 3 })
59*89c4ff92SAndroid Build Coastguard Worker     {
60*89c4ff92SAndroid Build Coastguard Worker         Setup();
61*89c4ff92SAndroid Build Coastguard Worker     }
62*89c4ff92SAndroid Build Coastguard Worker };
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeMultiAxesFixture : UnsqueezeFixture
65*89c4ff92SAndroid Build Coastguard Worker {
UnsqueezeMultiAxesFixtureUnsqueezeMultiAxesFixture66*89c4ff92SAndroid Build Coastguard Worker     UnsqueezeMultiAxesFixture() : UnsqueezeFixture({ 1, 3 }, { 3, 2, 5 }, { 3, 1, 2, 1, 5 })
67*89c4ff92SAndroid Build Coastguard Worker     {
68*89c4ff92SAndroid Build Coastguard Worker         Setup();
69*89c4ff92SAndroid Build Coastguard Worker     }
70*89c4ff92SAndroid Build Coastguard Worker };
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture
73*89c4ff92SAndroid Build Coastguard Worker {
UnsqueezeUnsortedAxesFixtureUnsqueezeUnsortedAxesFixture74*89c4ff92SAndroid Build Coastguard Worker     UnsqueezeUnsortedAxesFixture() : UnsqueezeFixture({ 3, 0, 1 }, { 2, 5 }, { 1, 1, 2, 1, 5 })
75*89c4ff92SAndroid Build Coastguard Worker     {
76*89c4ff92SAndroid Build Coastguard Worker         Setup();
77*89c4ff92SAndroid Build Coastguard Worker     }
78*89c4ff92SAndroid Build Coastguard Worker };
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeScalarFixture : UnsqueezeFixture
81*89c4ff92SAndroid Build Coastguard Worker {
UnsqueezeScalarFixtureUnsqueezeScalarFixture82*89c4ff92SAndroid Build Coastguard Worker     UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 })
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         Setup();
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker };
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest")
89*89c4ff92SAndroid Build Coastguard Worker {
90*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
91*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
92*89c4ff92SAndroid Build Coastguard Worker }
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeMultiAxesFixture, "UnsqueezeMultiAxesTest")
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
97*89c4ff92SAndroid Build Coastguard Worker                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
98*89c4ff92SAndroid Build Coastguard Worker                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
99*89c4ff92SAndroid Build Coastguard Worker                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
100*89c4ff92SAndroid Build Coastguard Worker                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
101*89c4ff92SAndroid Build Coastguard Worker                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
102*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
103*89c4ff92SAndroid Build Coastguard Worker                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
104*89c4ff92SAndroid Build Coastguard Worker                                     11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
105*89c4ff92SAndroid Build Coastguard Worker                                     16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
106*89c4ff92SAndroid Build Coastguard Worker                                     21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
107*89c4ff92SAndroid Build Coastguard Worker                                     26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest")
111*89c4ff92SAndroid Build Coastguard Worker {
112*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
113*89c4ff92SAndroid Build Coastguard Worker                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
114*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
115*89c4ff92SAndroid Build Coastguard Worker                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest")
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, float>({{"Input", { 1.0f }}},
121*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f }}});
122*89c4ff92SAndroid Build Coastguard Worker }
123*89c4ff92SAndroid Build Coastguard Worker 
124*89c4ff92SAndroid Build Coastguard Worker struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
125*89c4ff92SAndroid Build Coastguard Worker {
UnsqueezeInputAxesFixtureUnsqueezeInputAxesFixture126*89c4ff92SAndroid Build Coastguard Worker     UnsqueezeInputAxesFixture()
127*89c4ff92SAndroid Build Coastguard Worker     {
128*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
129*89c4ff92SAndroid Build Coastguard Worker                     ir_version: 8
130*89c4ff92SAndroid Build Coastguard Worker                     producer_name: "onnx-example"
131*89c4ff92SAndroid Build Coastguard Worker                     graph {
132*89c4ff92SAndroid Build Coastguard Worker                       node {
133*89c4ff92SAndroid Build Coastguard Worker                         input: "Input"
134*89c4ff92SAndroid Build Coastguard Worker                         input: "Axes"
135*89c4ff92SAndroid Build Coastguard Worker                         output: "Output"
136*89c4ff92SAndroid Build Coastguard Worker                         op_type: "Unsqueeze"
137*89c4ff92SAndroid Build Coastguard Worker                       }
138*89c4ff92SAndroid Build Coastguard Worker                       initializer {
139*89c4ff92SAndroid Build Coastguard Worker                           dims: 2
140*89c4ff92SAndroid Build Coastguard Worker                           data_type: 7
141*89c4ff92SAndroid Build Coastguard Worker                           int64_data: 0
142*89c4ff92SAndroid Build Coastguard Worker                           int64_data: 3
143*89c4ff92SAndroid Build Coastguard Worker                           name: "Axes"
144*89c4ff92SAndroid Build Coastguard Worker                         }
145*89c4ff92SAndroid Build Coastguard Worker                       name: "test-model"
146*89c4ff92SAndroid Build Coastguard Worker                       input {
147*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
148*89c4ff92SAndroid Build Coastguard Worker                         type {
149*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
150*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
151*89c4ff92SAndroid Build Coastguard Worker                             shape {
152*89c4ff92SAndroid Build Coastguard Worker                               dim {
153*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
154*89c4ff92SAndroid Build Coastguard Worker                               }
155*89c4ff92SAndroid Build Coastguard Worker                               dim {
156*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
157*89c4ff92SAndroid Build Coastguard Worker                               }
158*89c4ff92SAndroid Build Coastguard Worker                               dim {
159*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 5
160*89c4ff92SAndroid Build Coastguard Worker                               }
161*89c4ff92SAndroid Build Coastguard Worker                             }
162*89c4ff92SAndroid Build Coastguard Worker                           }
163*89c4ff92SAndroid Build Coastguard Worker                         }
164*89c4ff92SAndroid Build Coastguard Worker                       }
165*89c4ff92SAndroid Build Coastguard Worker                       output {
166*89c4ff92SAndroid Build Coastguard Worker                         name: "Output"
167*89c4ff92SAndroid Build Coastguard Worker                         type {
168*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
169*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
170*89c4ff92SAndroid Build Coastguard Worker                             shape {
171*89c4ff92SAndroid Build Coastguard Worker                               dim {
172*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
173*89c4ff92SAndroid Build Coastguard Worker                               }
174*89c4ff92SAndroid Build Coastguard Worker                               dim {
175*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 3
176*89c4ff92SAndroid Build Coastguard Worker                               }
177*89c4ff92SAndroid Build Coastguard Worker                               dim {
178*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 2
179*89c4ff92SAndroid Build Coastguard Worker                               }
180*89c4ff92SAndroid Build Coastguard Worker                               dim {
181*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
182*89c4ff92SAndroid Build Coastguard Worker                               }
183*89c4ff92SAndroid Build Coastguard Worker                               dim {
184*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 5
185*89c4ff92SAndroid Build Coastguard Worker                               }
186*89c4ff92SAndroid Build Coastguard Worker                             }
187*89c4ff92SAndroid Build Coastguard Worker                           }
188*89c4ff92SAndroid Build Coastguard Worker                         }
189*89c4ff92SAndroid Build Coastguard Worker                       }
190*89c4ff92SAndroid Build Coastguard Worker                     })";
191*89c4ff92SAndroid Build Coastguard Worker         Setup();
192*89c4ff92SAndroid Build Coastguard Worker     }
193*89c4ff92SAndroid Build Coastguard Worker };
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(UnsqueezeInputAxesFixture, "UnsqueezeInputAxesTest")
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker     RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
198*89c4ff92SAndroid Build Coastguard Worker                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
199*89c4ff92SAndroid Build Coastguard Worker                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
200*89c4ff92SAndroid Build Coastguard Worker                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
201*89c4ff92SAndroid Build Coastguard Worker                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
202*89c4ff92SAndroid Build Coastguard Worker                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
203*89c4ff92SAndroid Build Coastguard Worker                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
204*89c4ff92SAndroid Build Coastguard Worker                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
205*89c4ff92SAndroid Build Coastguard Worker                                     11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
206*89c4ff92SAndroid Build Coastguard Worker                                     16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
207*89c4ff92SAndroid Build Coastguard Worker                                     21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
208*89c4ff92SAndroid Build Coastguard Worker                                     26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
209*89c4ff92SAndroid Build Coastguard Worker }
210*89c4ff92SAndroid Build Coastguard Worker 
211*89c4ff92SAndroid Build Coastguard Worker }
212