xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Unsqueeze.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
#include "OnnxParserTestUtils.hpp"

#include <string>
#include <vector>
9 
10 TEST_SUITE("OnnxParser_Unsqueeze")
11 {
12 
13 struct UnsqueezeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
UnsqueezeFixtureUnsqueezeFixture15     UnsqueezeFixture(const std::vector<int>& axes,
16                      const std::vector<int>& inputShape,
17                      const std::vector<int>& outputShape)
18     {
19         m_Prototext = R"(
20                     ir_version: 8
21                     producer_name: "onnx-example"
22                     graph {
23                       node {
24                         input: "Input"
25                         output: "Output"
26                         op_type: "Unsqueeze"
27                         )" + armnnUtils::ConstructIntsAttribute("axes", axes) + R"(
28                       }
29                       name: "test-model"
30                       input {
31                         name: "Input"
32                         type {
33                           tensor_type {
34                             elem_type: 1
35                             shape {
36                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
37                             }
38                           }
39                         }
40                       }
41                       output {
42                         name: "Output"
43                         type {
44                           tensor_type {
45                             elem_type: 1
46                             shape {
47                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
48                             }
49                           }
50                         }
51                       }
52                     })";
53     }
54 };
55 
56 struct UnsqueezeSingleAxesFixture : UnsqueezeFixture
57 {
UnsqueezeSingleAxesFixtureUnsqueezeSingleAxesFixture58     UnsqueezeSingleAxesFixture() : UnsqueezeFixture({ 0 }, { 2, 3 }, { 1, 2, 3 })
59     {
60         Setup();
61     }
62 };
63 
64 struct UnsqueezeMultiAxesFixture : UnsqueezeFixture
65 {
UnsqueezeMultiAxesFixtureUnsqueezeMultiAxesFixture66     UnsqueezeMultiAxesFixture() : UnsqueezeFixture({ 1, 3 }, { 3, 2, 5 }, { 3, 1, 2, 1, 5 })
67     {
68         Setup();
69     }
70 };
71 
72 struct UnsqueezeUnsortedAxesFixture : UnsqueezeFixture
73 {
UnsqueezeUnsortedAxesFixtureUnsqueezeUnsortedAxesFixture74     UnsqueezeUnsortedAxesFixture() : UnsqueezeFixture({ 3, 0, 1 }, { 2, 5 }, { 1, 1, 2, 1, 5 })
75     {
76         Setup();
77     }
78 };
79 
80 struct UnsqueezeScalarFixture : UnsqueezeFixture
81 {
UnsqueezeScalarFixtureUnsqueezeScalarFixture82     UnsqueezeScalarFixture() : UnsqueezeFixture({ 0 }, { }, { 1 })
83     {
84         Setup();
85     }
86 };
87 
88 TEST_CASE_FIXTURE(UnsqueezeSingleAxesFixture, "UnsqueezeSingleAxesTest")
89 {
90     RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
91                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
92 }
93 
94 TEST_CASE_FIXTURE(UnsqueezeMultiAxesFixture, "UnsqueezeMultiAxesTest")
95 {
96     RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
97                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
98                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
99                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
100                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
101                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
102                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
103                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
104                                     11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
105                                     16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
106                                     21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
107                                     26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
108 }
109 
110 TEST_CASE_FIXTURE(UnsqueezeUnsortedAxesFixture, "UnsqueezeUnsortedAxesTest")
111 {
112     RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
113                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
114                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
115                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
116 }
117 
118 TEST_CASE_FIXTURE(UnsqueezeScalarFixture, "UnsqueezeScalarTest")
119 {
120     RunTest<1, float>({{"Input", { 1.0f }}},
121                       {{"Output", { 1.0f }}});
122 }
123 
124 struct UnsqueezeInputAxesFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
125 {
UnsqueezeInputAxesFixtureUnsqueezeInputAxesFixture126     UnsqueezeInputAxesFixture()
127     {
128         m_Prototext = R"(
129                     ir_version: 8
130                     producer_name: "onnx-example"
131                     graph {
132                       node {
133                         input: "Input"
134                         input: "Axes"
135                         output: "Output"
136                         op_type: "Unsqueeze"
137                       }
138                       initializer {
139                           dims: 2
140                           data_type: 7
141                           int64_data: 0
142                           int64_data: 3
143                           name: "Axes"
144                         }
145                       name: "test-model"
146                       input {
147                         name: "Input"
148                         type {
149                           tensor_type {
150                             elem_type: 1
151                             shape {
152                               dim {
153                                 dim_value: 3
154                               }
155                               dim {
156                                 dim_value: 2
157                               }
158                               dim {
159                                 dim_value: 5
160                               }
161                             }
162                           }
163                         }
164                       }
165                       output {
166                         name: "Output"
167                         type {
168                           tensor_type {
169                             elem_type: 1
170                             shape {
171                               dim {
172                                 dim_value: 1
173                               }
174                               dim {
175                                 dim_value: 3
176                               }
177                               dim {
178                                 dim_value: 2
179                               }
180                               dim {
181                                 dim_value: 1
182                               }
183                               dim {
184                                 dim_value: 5
185                               }
186                             }
187                           }
188                         }
189                       }
190                     })";
191         Setup();
192     }
193 };
194 
195 TEST_CASE_FIXTURE(UnsqueezeInputAxesFixture, "UnsqueezeInputAxesTest")
196 {
197     RunTest<5, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
198                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
199                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
200                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
201                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
202                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}},
203                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
204                                     6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
205                                     11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
206                                     16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
207                                     21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
208                                     26.0f, 27.0f, 28.0f, 29.0f, 30.0f }}});
209 }
210 
211 }
212