// Source: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Reshape.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include "armnnOnnxParser/IOnnxParser.hpp"
#include "ParserPrototxtFixture.hpp"
#include "OnnxParserTestUtils.hpp"

#include <string>
9 
10 TEST_SUITE("OnnxParser_Reshape")
11 {
12 struct ReshapeMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
13 {
ReshapeMainFixtureReshapeMainFixture14     ReshapeMainFixture(const std::string& dataType)
15     {
16         m_Prototext = R"(
17                    ir_version: 3
18                    producer_name:  "CNTK"
19                    producer_version:  "2.5.1"
20                    domain:  "ai.cntk"
21                    model_version: 1
22                    graph {
23                      name:  "CNTKGraph"
24                      input {
25                         name: "Input"
26                         type {
27                           tensor_type {
28                             elem_type: )" + dataType + R"(
29                             shape {
30                               dim {
31                                 dim_value: 4
32                               }
33                             }
34                           }
35                         }
36                       }
37                       input {
38                          name: "Shape"
39                          type {
40                            tensor_type {
41                              elem_type: 7
42                              shape {
43                                dim {
44                                  dim_value: 2
45                                }
46                              }
47                            }
48                          }
49                        }
50                      node {
51                          input: "Input"
52                          input: "Shape"
53                          output: "Output"
54                          name: "reshape"
55                          op_type: "Reshape"
56 
57                       }
58                       initializer {
59                         dims: 2
60                         data_type: 7
61                         int64_data: 2
62                         int64_data: 2
63                         name: "Shape"
64                      }
65                       output {
66                           name: "Output"
67                           type {
68                              tensor_type {
69                                elem_type: 1
70                                shape {
71                                    dim {
72                                        dim_value: 2
73                                    }
74                                    dim {
75                                        dim_value: 2
76                                    }
77                                }
78                             }
79                           }
80                        }
81                     }
82                    opset_import {
83                       version: 7
84                     })";
85     }
86 };
87 
88 struct ReshapeRank4Fixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
89 {
ReshapeRank4FixtureReshapeRank4Fixture90     ReshapeRank4Fixture(const std::string& dataType)
91     {
92         m_Prototext = R"(
93                    ir_version: 3
94                    producer_name:  "CNTK"
95                    producer_version:  "2.5.1"
96                    domain:  "ai.cntk"
97                    model_version: 1
98                    graph {
99                      name:  "CNTKGraph"
100                      input {
101                         name: "Input"
102                         type {
103                           tensor_type {
104                             elem_type: )" + dataType + R"(
105                             shape {
106                               dim {
107                                 dim_value: 2
108                               }
109                               dim {
110                                 dim_value: 2
111                               }
112                               dim {
113                                 dim_value: 3
114                               }
115                               dim {
116                                 dim_value: 3
117                               }
118                             }
119                           }
120                         }
121                       }
122                       input {
123                          name: "Shape"
124                          type {
125                            tensor_type {
126                              elem_type: 7
127                              shape {
128                                dim {
129                                  dim_value: 2
130                                }
131                              }
132                            }
133                          }
134                        }
135                      node {
136                          input: "Input"
137                          input: "Shape"
138                          output: "Output"
139                          name: "reshape"
140                          op_type: "Reshape"
141 
142                       }
143                       initializer {
144                         dims: 2
145                         data_type: 7
146                         int64_data: 2
147                         int64_data: 2
148                         name: "Shape"
149                      }
150                       output {
151                           name: "Output"
152                           type {
153                              tensor_type {
154                                elem_type: 1
155                                shape {
156                                    dim {
157                                        dim_value: 6
158                                    }
159                                    dim {
160                                        dim_value: 6
161                                    }
162                                }
163                             }
164                           }
165                        }
166                     }
167                    opset_import {
168                       version: 7
169                     })";
170     }
171 };
172 
173 struct ReshapeValidFixture : ReshapeMainFixture
174 {
ReshapeValidFixtureReshapeValidFixture175     ReshapeValidFixture() : ReshapeMainFixture("1") {
176         Setup();
177     }
178 };
179 
180 struct ReshapeValidRank4Fixture : ReshapeRank4Fixture
181 {
ReshapeValidRank4FixtureReshapeValidRank4Fixture182     ReshapeValidRank4Fixture() : ReshapeRank4Fixture("1") {
183         Setup();
184     }
185 };
186 
187 struct ReshapeInvalidFixture : ReshapeMainFixture
188 {
ReshapeInvalidFixtureReshapeInvalidFixture189     ReshapeInvalidFixture() : ReshapeMainFixture("10") { }
190 };
191 
192 TEST_CASE_FIXTURE(ReshapeValidFixture, "ValidReshapeTest")
193 {
194     RunTest<2>({{"Input", { 0.0f, 1.0f, 2.0f, 3.0f }}}, {{"Output", { 0.0f, 1.0f, 2.0f, 3.0f }}});
195 }
196 
197 TEST_CASE_FIXTURE(ReshapeValidRank4Fixture, "ValidRank4ReshapeTest")
198 {
199     RunTest<2>(
200         {{"Input",
201                    {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
202                     1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
203                     1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}},
204         {{"Output",
205                     {1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
206                      1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
207                      1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f}}});
208 }
209 
210 TEST_CASE_FIXTURE(ReshapeInvalidFixture, "IncorrectDataTypeReshape")
211 {
212    CHECK_THROWS_AS(Setup(), armnn::ParseException);
213 }
214 
215 struct ReshapeNegativeReshapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
216 {
ReshapeNegativeReshapeFixtureReshapeNegativeReshapeFixture217     ReshapeNegativeReshapeFixture(const std::vector<int>& inputShape,
218                                   const std::vector<int>& shapeInputShape,
219                                   const std::vector<int>& outputShape,
220                                   const std::string& shape)
221         {
222         m_Prototext = R"(
223                    ir_version: 3
224                    producer_name: "onnx-example"
225                    graph {
226                      name:  "ReshapeGrapn"
227                      input {
228                         name: "Input"
229                         type {
230                           tensor_type {
231                             elem_type: 1
232                             shape {
233                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
234                             }
235                           }
236                         }
237                       }
238                       input {
239                          name: "Shape"
240                          type {
241                            tensor_type {
242                              elem_type: 7
243                              shape {
244                                )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
245                              }
246                            }
247                          }
248                        }
249                      node {
250                          input: "Input"
251                          input: "Shape"
252                          output: "Output"
253                          name: "reshape"
254                          op_type: "Reshape"
255                       }
256                       initializer {
257                         dims: 2
258                         data_type: 7
259                         )" + shape + R"(
260                         name: "Shape"
261                      }
262                       output {
263                           name: "Output"
264                           type {
265                              tensor_type {
266                                elem_type: 1
267                                shape {
268                                  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
269                                }
270                             }
271                           }
272                        }
273                     }
274                    opset_import {
275                       version: 7
276                    })";
277     }
278 };
279 
280 struct ReshapeNegativeReshape1DFixture : ReshapeNegativeReshapeFixture
281 {
ReshapeNegativeReshape1DFixtureReshapeNegativeReshape1DFixture282     ReshapeNegativeReshape1DFixture() : ReshapeNegativeReshapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 }, "int64_data: -1")
283     {
284         Setup();
285     }
286 };
287 
288 struct ReshapeNegativeReshape2DFixture : ReshapeNegativeReshapeFixture
289 {
ReshapeNegativeReshape2DFixtureReshapeNegativeReshape2DFixture290     ReshapeNegativeReshape2DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
291                                                                       { 2 },
292                                                                       { 2, 6 },
293                                                                       "int64_data: -1  int64_data: 6")
294     {
295         Setup();
296     }
297 };
298 
299 struct ReshapeNegativeReshape3DFixture : ReshapeNegativeReshapeFixture
300 {
ReshapeNegativeReshape3DFixtureReshapeNegativeReshape3DFixture301     ReshapeNegativeReshape3DFixture() : ReshapeNegativeReshapeFixture({ 2, 3, 1, 2 },
302                                                                       { 3 },
303                                                                       { 3, 1, 4 },
304                                                                       "int64_data: 3  int64_data: -1  int64_data: 4")
305     {
306         Setup();
307     }
308 };
309 
310 struct ReshapeNegativeReshape4DFixture : ReshapeNegativeReshapeFixture
311 {
ReshapeNegativeReshape4DFixtureReshapeNegativeReshape4DFixture312     ReshapeNegativeReshape4DFixture() : ReshapeNegativeReshapeFixture(
313         { 2, 3, 1, 2 },
314         { 4 },
315         { 3, 1, 2, 2 },
316         "int64_data: 3  int64_data: 1  int64_data: 2  int64_data: -1")
317     {
318         Setup();
319     }
320 };
321 
322 TEST_CASE_FIXTURE(ReshapeNegativeReshape1DFixture, "ReshapeNegativeReshape1DTest")
323 {
324     RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
325                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
326 }
327 
328 TEST_CASE_FIXTURE(ReshapeNegativeReshape2DFixture, "ReshapeNegativeReshape2DTest")
329 {
330     RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
331                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
332                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
333                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
334 }
335 
336 TEST_CASE_FIXTURE(ReshapeNegativeReshape3DFixture, "ReshapeNegativeReshape3DTest")
337 {
338     RunTest<3, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
339                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
340                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
341                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
342 }
343 
344 TEST_CASE_FIXTURE(ReshapeNegativeReshape4DFixture, "ReshapeNegativeReshape4DTest")
345 {
346     RunTest<4, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
347                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}},
348                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
349                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f }}});
350 }
351 
352 struct ReshapeNonConstShapeFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
353 {
ReshapeNonConstShapeFixtureReshapeNonConstShapeFixture354     ReshapeNonConstShapeFixture(const std::vector<int>& inputShape,
355                                 const std::vector<int>& shapeInputShape,
356                                 const std::vector<int>& outputShape)
357     {
358         m_Prototext = R"(
359                    ir_version: 3
360                    producer_name: "onnx-example"
361                    graph {
362                      name:  "ReshapeGrapn"
363                      input {
364                         name: "Input"
365                         type {
366                           tensor_type {
367                             elem_type: 1
368                             shape {
369                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
370                             }
371                           }
372                         }
373                       }
374                       input {
375                          name: "Shape"
376                          type {
377                            tensor_type {
378                              elem_type: 7
379                              shape {
380                                )" + armnnUtils::ConstructTensorShapeString(shapeInputShape) + R"(
381                              }
382                            }
383                          }
384                        }
385                      node {
386                          input: "Input"
387                          input: "Shape"
388                          output: "Output"
389                          name: "reshape"
390                          op_type: "Reshape"
391                       }
392                       output {
393                           name: "Output"
394                           type {
395                              tensor_type {
396                                elem_type: 1
397                                shape {
398                                  )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
399                                }
400                             }
401                           }
402                        }
403                     }
404                    opset_import {
405                       version: 7
406                    })";
407     }
408 };
409 
410 struct ReshapeNonConst1DShapeFixture : ReshapeNonConstShapeFixture
411 {
ReshapeNonConst1DShapeFixtureReshapeNonConst1DShapeFixture412     ReshapeNonConst1DShapeFixture() : ReshapeNonConstShapeFixture({ 1, 3, 1, 2 }, { 1 }, { 6 })
413     {
414         Setup();
415     }
416 };
417 
418 struct ReshapeNonConst2DShapeFixture : ReshapeNonConstShapeFixture
419 {
ReshapeNonConst2DShapeFixtureReshapeNonConst2DShapeFixture420     ReshapeNonConst2DShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 2 }, { 2, 12 })
421     {
422         Setup();
423     }
424 };
425 
426 struct ReshapeInvalidNonConstShapeFixture : ReshapeNonConstShapeFixture
427 {
ReshapeInvalidNonConstShapeFixtureReshapeInvalidNonConstShapeFixture428     ReshapeInvalidNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 3 }, { 2, 3, 4 })
429     {
430     }
431 };
432 
433 struct ReshapeInvalidDimNonConstShapeFixture : ReshapeNonConstShapeFixture
434 {
ReshapeInvalidDimNonConstShapeFixtureReshapeInvalidDimNonConstShapeFixture435     ReshapeInvalidDimNonConstShapeFixture() : ReshapeNonConstShapeFixture({ 2, 3, 2, 2 }, { 1, 2 }, { 2, 3, 4 })
436     {
437     }
438 };
439 
440 TEST_CASE_FIXTURE(ReshapeNonConst1DShapeFixture, "ReshapeNonConst1DShapeTest")
441 {
442     RunTest<1, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}},
443                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f }}});
444 }
445 
446 TEST_CASE_FIXTURE(ReshapeNonConst2DShapeFixture, "ReshapeNonConst2DShapeTest")
447 {
448     RunTest<2, float>({{"Input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
449                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
450                                    13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
451                                    19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}},
452                       {{"Output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
453                                     7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
454                                     13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
455                                     19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f }}});
456 }
457 
458 TEST_CASE_FIXTURE(ReshapeInvalidNonConstShapeFixture, "ReshapeInvalidNonConstShapeTest")
459 {
460     CHECK_THROWS_AS(Setup(), armnn::ParseException);
461 }
462 
463 TEST_CASE_FIXTURE(ReshapeInvalidDimNonConstShapeFixture, "ReshapeInvalidDimNonConstShapeTest")
464 {
465     CHECK_THROWS_AS(Setup(), armnn::ParseException);
466 }
467 
468 }
469