xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Flatten.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_Flatter")
10 {
11 struct FlattenMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
FlattenMainFixtureFlattenMainFixture13     FlattenMainFixture(const std::string& dataType)
14     {
15         m_Prototext = R"(
16                    ir_version: 3
17                    producer_name:  "CNTK"
18                    producer_version:  "2.5.1"
19                    domain:  "ai.cntk"
20                    model_version: 1
21                    graph {
22                      name:  "CNTKGraph"
23                      input {
24                         name: "Input"
25                         type {
26                           tensor_type {
27                             elem_type: )" + dataType + R"(
28                             shape {
29                               dim {
30                                 dim_value: 2
31                               }
32                               dim {
33                                 dim_value: 2
34                               }
35                               dim {
36                                 dim_value: 3
37                               }
38                               dim {
39                                 dim_value: 3
40                               }
41                             }
42                           }
43                         }
44                       }
45                      node {
46                          input: "Input"
47                          output: "Output"
48                          name: "flatten"
49                          op_type: "Flatten"
50                          attribute {
51                            name: "axis"
52                            i: 2
53                            type: INT
54                          }
55                       }
56                       output {
57                           name: "Output"
58                           type {
59                              tensor_type {
60                                elem_type: 1
61                                shape {
62                                    dim {
63                                        dim_value: 4
64                                    }
65                                    dim {
66                                        dim_value: 9
67                                    }
68                                }
69                             }
70                           }
71                        }
72                     }
73                    opset_import {
74                       version: 7
75                     })";
76     }
77 };
78 
79 struct FlattenDefaultAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
80 {
FlattenDefaultAxisFixtureFlattenDefaultAxisFixture81     FlattenDefaultAxisFixture(const std::string& dataType)
82     {
83         m_Prototext = R"(
84                    ir_version: 3
85                    producer_name:  "CNTK"
86                    producer_version:  "2.5.1"
87                    domain:  "ai.cntk"
88                    model_version: 1
89                    graph {
90                      name:  "CNTKGraph"
91                      input {
92                         name: "Input"
93                         type {
94                           tensor_type {
95                             elem_type: )" + dataType + R"(
96                             shape {
97                               dim {
98                                 dim_value: 2
99                               }
100                               dim {
101                                 dim_value: 2
102                               }
103                               dim {
104                                 dim_value: 3
105                               }
106                               dim {
107                                 dim_value: 3
108                               }
109                             }
110                           }
111                         }
112                       }
113                      node {
114                          input: "Input"
115                          output: "Output"
116                          name: "flatten"
117                          op_type: "Flatten"
118                       }
119                       output {
120                           name: "Output"
121                           type {
122                              tensor_type {
123                                elem_type: 1
124                                shape {
125                                    dim {
126                                        dim_value: 2
127                                    }
128                                    dim {
129                                        dim_value: 18
130                                    }
131                                }
132                             }
133                           }
134                        }
135                     }
136                    opset_import {
137                       version: 7
138                     })";
139     }
140 };
141 
142 struct FlattenAxisZeroFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
143 {
FlattenAxisZeroFixtureFlattenAxisZeroFixture144     FlattenAxisZeroFixture(const std::string& dataType)
145     {
146         m_Prototext = R"(
147                    ir_version: 3
148                    producer_name:  "CNTK"
149                    producer_version:  "2.5.1"
150                    domain:  "ai.cntk"
151                    model_version: 1
152                    graph {
153                      name:  "CNTKGraph"
154                      input {
155                         name: "Input"
156                         type {
157                           tensor_type {
158                             elem_type: )" + dataType + R"(
159                             shape {
160                               dim {
161                                 dim_value: 2
162                               }
163                               dim {
164                                 dim_value: 2
165                               }
166                               dim {
167                                 dim_value: 3
168                               }
169                               dim {
170                                 dim_value: 3
171                               }
172                             }
173                           }
174                         }
175                       }
176                      node {
177                          input: "Input"
178                          output: "Output"
179                          name: "flatten"
180                          op_type: "Flatten"
181                          attribute {
182                            name: "axis"
183                            i: 0
184                            type: INT
185                          }
186                       }
187                       output {
188                           name: "Output"
189                           type {
190                              tensor_type {
191                                elem_type: 1
192                                shape {
193                                    dim {
194                                        dim_value: 1
195                                    }
196                                    dim {
197                                        dim_value: 36
198                                    }
199                                }
200                             }
201                           }
202                        }
203                     }
204                    opset_import {
205                       version: 7
206                     })";
207     }
208 };
209 
210 struct FlattenNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
211 {
FlattenNegativeAxisFixtureFlattenNegativeAxisFixture212     FlattenNegativeAxisFixture(const std::string& dataType)
213     {
214         m_Prototext = R"(
215                    ir_version: 3
216                    producer_name:  "CNTK"
217                    producer_version:  "2.5.1"
218                    domain:  "ai.cntk"
219                    model_version: 1
220                    graph {
221                      name:  "CNTKGraph"
222                      input {
223                         name: "Input"
224                         type {
225                           tensor_type {
226                             elem_type: )" + dataType + R"(
227                             shape {
228                               dim {
229                                 dim_value: 2
230                               }
231                               dim {
232                                 dim_value: 2
233                               }
234                               dim {
235                                 dim_value: 3
236                               }
237                               dim {
238                                 dim_value: 3
239                               }
240                             }
241                           }
242                         }
243                       }
244                      node {
245                          input: "Input"
246                          output: "Output"
247                          name: "flatten"
248                          op_type: "Flatten"
249                          attribute {
250                            name: "axis"
251                            i: -1
252                            type: INT
253                          }
254                       }
255                       output {
256                           name: "Output"
257                           type {
258                              tensor_type {
259                                elem_type: 1
260                                shape {
261                                    dim {
262                                        dim_value: 12
263                                    }
264                                    dim {
265                                        dim_value: 3
266                                    }
267                                }
268                             }
269                           }
270                        }
271                     }
272                    opset_import {
273                       version: 7
274                     })";
275     }
276 };
277 
278 struct FlattenInvalidNegativeAxisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
279 {
FlattenInvalidNegativeAxisFixtureFlattenInvalidNegativeAxisFixture280     FlattenInvalidNegativeAxisFixture(const std::string& dataType)
281     {
282         m_Prototext = R"(
283                    ir_version: 3
284                    producer_name:  "CNTK"
285                    producer_version:  "2.5.1"
286                    domain:  "ai.cntk"
287                    model_version: 1
288                    graph {
289                      name:  "CNTKGraph"
290                      input {
291                         name: "Input"
292                         type {
293                           tensor_type {
294                             elem_type: )" + dataType + R"(
295                             shape {
296                               dim {
297                                 dim_value: 2
298                               }
299                               dim {
300                                 dim_value: 2
301                               }
302                               dim {
303                                 dim_value: 3
304                               }
305                               dim {
306                                 dim_value: 3
307                               }
308                             }
309                           }
310                         }
311                       }
312                      node {
313                          input: "Input"
314                          output: "Output"
315                          name: "flatten"
316                          op_type: "Flatten"
317                          attribute {
318                            name: "axis"
319                            i: -5
320                            type: INT
321                          }
322                       }
323                       output {
324                           name: "Output"
325                           type {
326                              tensor_type {
327                                elem_type: 1
328                                shape {
329                                    dim {
330                                        dim_value: 12
331                                    }
332                                    dim {
333                                        dim_value: 3
334                                    }
335                                }
336                             }
337                           }
338                        }
339                     }
340                    opset_import {
341                       version: 7
342                     })";
343     }
344 };
345 
346 struct FlattenValidFixture : FlattenMainFixture
347 {
FlattenValidFixtureFlattenValidFixture348     FlattenValidFixture() : FlattenMainFixture("1") {
349         Setup();
350     }
351 };
352 
353 struct FlattenDefaultValidFixture : FlattenDefaultAxisFixture
354 {
FlattenDefaultValidFixtureFlattenDefaultValidFixture355     FlattenDefaultValidFixture() : FlattenDefaultAxisFixture("1") {
356         Setup();
357     }
358 };
359 
360 struct FlattenAxisZeroValidFixture : FlattenAxisZeroFixture
361 {
FlattenAxisZeroValidFixtureFlattenAxisZeroValidFixture362     FlattenAxisZeroValidFixture() : FlattenAxisZeroFixture("1") {
363         Setup();
364     }
365 };
366 
367 struct FlattenNegativeAxisValidFixture : FlattenNegativeAxisFixture
368 {
FlattenNegativeAxisValidFixtureFlattenNegativeAxisValidFixture369     FlattenNegativeAxisValidFixture() : FlattenNegativeAxisFixture("1") {
370         Setup();
371     }
372 };
373 
374 struct FlattenInvalidFixture : FlattenMainFixture
375 {
FlattenInvalidFixtureFlattenInvalidFixture376     FlattenInvalidFixture() : FlattenMainFixture("10") { }
377 };
378 
379 struct FlattenInvalidAxisFixture : FlattenInvalidNegativeAxisFixture
380 {
FlattenInvalidAxisFixtureFlattenInvalidAxisFixture381     FlattenInvalidAxisFixture() : FlattenInvalidNegativeAxisFixture("1") { }
382 };
383 
384 TEST_CASE_FIXTURE(FlattenValidFixture, "ValidFlattenTest")
385 {
386     RunTest<2>({{"Input",
387                           { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
388                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
389                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
390                 {{"Output",
391                           { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
392                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
393                             1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
394 }
395 
396 TEST_CASE_FIXTURE(FlattenDefaultValidFixture, "ValidFlattenDefaultTest")
397 {
398     RunTest<2>({{"Input",
399                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
400                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
401                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
402                {{"Output",
403                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
404                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
405                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
406 }
407 
408 TEST_CASE_FIXTURE(FlattenAxisZeroValidFixture, "ValidFlattenAxisZeroTest")
409 {
410     RunTest<2>({{"Input",
411                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
412                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
413                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
414                {{"Output",
415                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
416                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
417                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
418 }
419 
420 TEST_CASE_FIXTURE(FlattenNegativeAxisValidFixture, "ValidFlattenNegativeAxisTest")
421 {
422     RunTest<2>({{"Input",
423                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
424                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
425                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}},
426                {{"Output",
427                     { 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
428                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f,
429                         1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f, 1.0f, 2.0f, 3.0f, 4.0f }}});
430 }
431 
432 TEST_CASE_FIXTURE(FlattenInvalidFixture, "IncorrectDataTypeFlatten")
433 {
434     CHECK_THROWS_AS(Setup(), armnn::ParseException);
435 }
436 
437 TEST_CASE_FIXTURE(FlattenInvalidAxisFixture, "IncorrectAxisFlatten")
438 {
439     CHECK_THROWS_AS(Setup(), armnn::ParseException);
440 }
441 
442 }
443