xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Gather.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include "ParserPrototxtFixture.hpp"
8 #include "OnnxParserTestUtils.hpp"
9 
10 TEST_SUITE("OnnxParser_Gather")
11 {
12 
13 struct GatherMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
GatherMainFixtureGatherMainFixture15     GatherMainFixture(const std::vector<int>& indicesShape,
16                       const std::vector<int>& indices,
17                       const std::vector<int>& inputShape,
18                       const std::vector<int>& outputShape)
19     {
20         m_Prototext = R"(
21                     ir_version: 8
22                     producer_name: "onnx-example"
23                     graph {
24                       node {
25                         output: "indices"
26                         op_type: "Constant"
27                         attribute {
28                           name: "value"
29                           t {
30                             data_type: 7
31                             )" + ConstructIndicesString(indicesShape, indices) + R"(
32                             name: "value"
33                           }
34                           type: TENSOR
35                         }
36                       }
37                       node {
38                         input: "input"
39                         input: "indices"
40                         output: "output"
41                         op_type: "Gather"
42                         attribute {
43                           name: "axis"
44                           i: 0
45                           type: INT
46                         }
47                       }
48                       name: "gather-model"
49                       input {
50                         name: "input"
51                         type {
52                           tensor_type {
53                             elem_type: 1
54                             shape {
55                               )" + armnnUtils::ConstructTensorShapeString(inputShape) + R"(
56                             }
57                           }
58                         }
59                       }
60                       output {
61                         name: "output"
62                         type {
63                           tensor_type {
64                             elem_type: 1
65                             shape {
66                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
67                             }
68                           }
69                         }
70                       }
71                     })";
72     }
ConstructIndicesStringGatherMainFixture73     std::string ConstructIndicesString(const std::vector<int>& indicesShape, const std::vector<int>& indices)
74     {
75         std::string shapeStr;
76         for (int i : indicesShape)
77         {
78             shapeStr = fmt::format(" {} dims: {}", shapeStr, i);
79         }
80         for (int i : indices)
81         {
82             shapeStr = fmt::format(" {} int64_data: {}", shapeStr, i);
83         }
84         return shapeStr;
85     }
86 };
87 
88 struct GatherScalarFixture : GatherMainFixture
89 {
GatherScalarFixtureGatherScalarFixture90     GatherScalarFixture() : GatherMainFixture({ }, { 0 }, { 8 }, { })
91     {
92         Setup();
93     }
94 };
95 
96 struct Gather1dFixture : GatherMainFixture
97 {
Gather1dFixtureGather1dFixture98     Gather1dFixture() : GatherMainFixture({ 4 }, { 0, 2, 1, 5 }, { 8 }, { 4 })
99     {
100         Setup();
101     }
102 };
103 
104 struct Gather2dFixture : GatherMainFixture
105 {
Gather2dFixtureGather2dFixture106     Gather2dFixture() : GatherMainFixture({ 3 }, { 1, 3, 4 }, { 5, 2 }, { 3, 2 })
107     {
108         Setup();
109     }
110 };
111 
112 struct Gather3dMultiIndicesFixture : GatherMainFixture
113 {
Gather3dMultiIndicesFixtureGather3dMultiIndicesFixture114     Gather3dMultiIndicesFixture() : GatherMainFixture({ 2, 3 }, { 1, 2, 1, 2, 1, 0 }, { 3, 2, 3 }, { 2, 3, 2, 3 })
115     {
116         Setup();
117     }
118 };
119 
120 struct Gather4dFixture : GatherMainFixture
121 {
Gather4dFixtureGather4dFixture122     Gather4dFixture() : GatherMainFixture({ 3 }, { 0, 1, 3 }, { 5, 4, 3, 2 }, { 3, 4, 3, 2 })
123     {
124         Setup();
125     }
126 };
127 
128 TEST_CASE_FIXTURE(GatherScalarFixture, "GatherScalarTest")
129 {
130     RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
131                       {{"output", { 1.0f }}});
132 }
133 
134 TEST_CASE_FIXTURE(Gather1dFixture, "Gather1dTest")
135 {
136     RunTest<1, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }}},
137                       {{"output", { 1.0f, 3.0f, 2.0f, 6.0f }}});
138 }
139 
140 TEST_CASE_FIXTURE(Gather2dFixture, "Gather2dTest")
141 {
142     RunTest<2, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f }}},
143                       {{"output", { 3.0f, 4.0f, 7.0f, 8.0f, 9.0f, 10.0f }}});
144 }
145 
146 TEST_CASE_FIXTURE(Gather3dMultiIndicesFixture, "Gather3dMultiIndicesTest")
147 {
148     RunTest<3, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
149                                    7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
150                                    13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f }}},
151                       {{"output", { 7.0f,  8.0f,  9.0f,
152                                     10.0f, 11.0f, 12.0f,
153                                     13.0f, 14.0f, 15.0f,
154                                     16.0f, 17.0f, 18.0f,
155                                     7.0f,  8.0f,  9.0f,
156                                     10.0f, 11.0f, 12.0f,
157                                     13.0f, 14.0f, 15.0f,
158                                     16.0f, 17.0f, 18.0f,
159                                     7.0f,  8.0f,  9.0f,
160                                     10.0f, 11.0f, 12.0f,
161                                     1.0f,  2.0f,  3.0f,
162                                     4.0f,  5.0f,  6.0f }}});
163 }
164 
165 TEST_CASE_FIXTURE(Gather4dFixture, "Gather4dTest")
166 {
167     RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
168                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
169                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
170                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
171                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
172                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
173                                    31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
174                                    36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
175                                    41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
176                                    46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
177                                    51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
178                                    56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
179                                    61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
180                                    66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
181                                    71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
182                                    76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
183                                    81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
184                                    86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
185                                    91.0f, 92.0f, 93.0f, 94.0f, 95.0f,
186                                    96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
187                                    101.0f, 102.0f, 103.0f, 104.0f, 105.0f,
188                                    106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
189                                    111.0f, 112.0f, 113.0f, 114.0f, 115.0f,
190                                    116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}},
191                       {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
192                                     7.0f,  8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
193                                     13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
194                                     19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
195                                     25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
196                                     31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,
197                                     37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f,
198                                     43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
199                                     73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f,
200                                     79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f,
201                                     85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
202                                     91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}});
203 }
204 
205 struct GatherRawDataFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
206 {
GatherRawDataFixtureGatherRawDataFixture207     GatherRawDataFixture()
208     {
209         m_Prototext = R"(
210                     ir_version: 8
211                     producer_name: "onnx-example"
212                     graph {
213                       node {
214                         output: "indices"
215                         op_type: "Constant"
216                         attribute {
217                           name: "value"
218                           t {
219                             dims: 3
220                             data_type: 7
221                             raw_data:
222                       "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000\003\000\000\000\000\000\000\000"
223                             name: "value"
224                           }
225                           type: TENSOR
226                         }
227                       }
228                       node {
229                         input: "input"
230                         input: "indices"
231                         output: "output"
232                         op_type: "Gather"
233                         attribute {
234                           name: "axis"
235                           i: 0
236                           type: INT
237                         }
238                       }
239                       name: "gather-model"
240                       input {
241                         name: "input"
242                         type {
243                           tensor_type {
244                             elem_type: 1
245                             shape {
246                               dim {
247                                 dim_value: 5
248                               }
249                               dim {
250                                 dim_value: 4
251                               }
252                               dim {
253                                 dim_value: 3
254                               }
255                               dim {
256                                 dim_value: 2
257                               }
258                             }
259                           }
260                         }
261                       }
262                       output {
263                         name: "output"
264                         type {
265                           tensor_type {
266                             elem_type: 1
267                             shape {
268                               dim {
269                                 dim_value: 3
270                               }
271                               dim {
272                                 dim_value: 4
273                               }
274                               dim {
275                                 dim_value: 3
276                               }
277                               dim {
278                                 dim_value: 2
279                               }
280                             }
281                           }
282                         }
283                       }
284                     })";
285         Setup();
286     }
287 };
288 
289 TEST_CASE_FIXTURE(GatherRawDataFixture, "GatherRawDataTest")
290 {
291     RunTest<4, float>({{"input", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
292                                    6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
293                                    11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
294                                    16.0f, 17.0f, 18.0f, 19.0f, 20.0f,
295                                    21.0f, 22.0f, 23.0f, 24.0f, 25.0f,
296                                    26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
297                                    31.0f, 32.0f, 33.0f, 34.0f, 35.0f,
298                                    36.0f, 37.0f, 38.0f, 39.0f, 40.0f,
299                                    41.0f, 42.0f, 43.0f, 44.0f, 45.0f,
300                                    46.0f, 47.0f, 48.0f, 49.0f, 50.0f,
301                                    51.0f, 52.0f, 53.0f, 54.0f, 55.0f,
302                                    56.0f, 57.0f, 58.0f, 59.0f, 60.0f,
303                                    61.0f, 62.0f, 63.0f, 64.0f, 65.0f,
304                                    66.0f, 67.0f, 68.0f, 69.0f, 70.0f,
305                                    71.0f, 72.0f, 73.0f, 74.0f, 75.0f,
306                                    76.0f, 77.0f, 78.0f, 79.0f, 80.0f,
307                                    81.0f, 82.0f, 83.0f, 84.0f, 85.0f,
308                                    86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
309                                    91.0f, 92.0f, 93.0f, 94.0f, 95.0f,
310                                    96.0f, 97.0f, 98.0f, 99.0f, 100.0f,
311                                    101.0f, 102.0f, 103.0f, 104.0f, 105.0f,
312                                    106.0f, 107.0f, 108.0f, 109.0f, 110.0f,
313                                    111.0f, 112.0f, 113.0f, 114.0f, 115.0f,
314                                    116.0f, 117.0f, 118.0f, 119.0f, 120.0f }}},
315                       {{"output", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f,
316                                     7.0f,  8.0f, 9.0f, 10.0f, 11.0f, 12.0f,
317                                     13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f,
318                                     19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
319                                     25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f,
320                                     31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f,
321                                     37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f,
322                                     43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f,
323                                     73.0f, 74.0f, 75.0f, 76.0f, 77.0f, 78.0f,
324                                     79.0f, 80.0f, 81.0f, 82.0f, 83.0f, 84.0f,
325                                     85.0f, 86.0f, 87.0f, 88.0f, 89.0f, 90.0f,
326                                     91.0f, 92.0f, 93.0f, 94.0f, 95.0f, 96.0f }}});
327 }
328 
329 }
330