xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Gemm.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include "ParserPrototxtFixture.hpp"
8 #include "OnnxParserTestUtils.hpp"
9 
10 TEST_SUITE("OnnxParser_Gemm")
11 {
12 
13 struct GemmFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
14 {
GemmFixtureGemmFixture15     GemmFixture(const std::string& alpha,
16                 const std::string& beta,
17                 const std::string& transA,
18                 const std::string& transB,
19                 const std::vector<int>& inputAShape,
20                 const std::vector<int>& inputBShape,
21                 const std::vector<int>& inputCShape,
22                 const std::vector<int>& outputShape)
23     {
24         m_Prototext = R"(
25                     ir_version: 8
26                     producer_name: "onnx-example"
27                     graph {
28                       node {
29                         input: "A"
30                         input: "B"
31                         input: "C"
32                         output: "Output"
33                         op_type: "Gemm"
34                         attribute {
35                           name: "alpha"
36                           f: )" + alpha + R"(
37                           type: FLOAT
38                         }
39                         attribute {
40                           name: "beta"
41                           f: )" + beta + R"(
42                           type: FLOAT
43                         }
44                         attribute {
45                           name: "transA"
46                           i: )" + transA + R"(
47                           type: INT
48                         }
49                         attribute {
50                           name: "transB"
51                           i: )" + transB + R"(
52                           type: INT
53                         }
54                       }
55                       name: "gem-model"
56                       input {
57                         name: "A"
58                         type {
59                           tensor_type {
60                             elem_type: 1
61                             shape {
62                               )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
63                             }
64                           }
65                         }
66                       }
67                       input {
68                         name: "B"
69                         type {
70                           tensor_type {
71                             elem_type: 1
72                             shape {
73                               )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
74                             }
75                           }
76                         }
77                       }
78                       input {
79                         name: "C"
80                         type {
81                           tensor_type {
82                             elem_type: 1
83                             shape {
84                               )" + armnnUtils::ConstructTensorShapeString(inputCShape) + R"(
85                             }
86                           }
87                         }
88                       }
89                       output {
90                         name: "Output"
91                         type {
92                           tensor_type {
93                             elem_type: 1
94                             shape {
95                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
96                             }
97                           }
98                         }
99                       }
100                     })";
101     }
102 };
103 
104 struct GemmAllAttributesFixture : GemmFixture
105 {
GemmAllAttributesFixtureGemmAllAttributesFixture106     GemmAllAttributesFixture() : GemmFixture("0.25", "0.35", "1", "1", { 4, 3 }, { 5, 4 }, { 5 }, { 3, 5 })
107     {
108         Setup();
109     }
110 };
111 
112 struct GemmSimpleFixture : GemmFixture
113 {
GemmSimpleFixtureGemmSimpleFixture114     GemmSimpleFixture() : GemmFixture("1", "1", "0", "0", { 3, 4 }, { 4, 5 }, { 5 }, { 3, 5 })
115     {
116         Setup();
117     }
118 };
119 
120 struct GemmTransAFixture : GemmFixture
121 {
GemmTransAFixtureGemmTransAFixture122     GemmTransAFixture() : GemmFixture("1", "1", "1", "0", { 4, 3 }, { 4, 5 }, { 5 }, { 3, 5 })
123     {
124         Setup();
125     }
126 };
127 
128 struct GemmTransBFixture : GemmFixture
129 {
GemmTransBFixtureGemmTransBFixture130     GemmTransBFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 5 }, { 3, 5 })
131     {
132         Setup();
133     }
134 };
135 
136 struct GemmParseExceptionFixture : GemmFixture
137 {
GemmParseExceptionFixtureGemmParseExceptionFixture138     GemmParseExceptionFixture() : GemmFixture("1", "1", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }, { 3, 5 }) {}
139 };
140 
141 TEST_CASE_FIXTURE(GemmAllAttributesFixture, "GemmTest")
142 {
143     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
144                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
145                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
146                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
147                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
148                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
149                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
150                       {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
151                                     12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
152                                     10.035f, 32.07f,  54.105f, 76.14f, 98.175f }}});
153 }
154 
155 TEST_CASE_FIXTURE(GemmSimpleFixture, "GemmSimpleTest")
156 {
157     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
158                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
159                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
160                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
161                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
162                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
163                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
164                       {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
165                                     196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
166                                     60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
167 }
168 
169 TEST_CASE_FIXTURE(GemmTransAFixture, "GemmTransposeATest")
170 {
171     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
172                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
173                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
174                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
175                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
176                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
177                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
178                       {{"Output", { 180.1f, 210.2f, 240.3f, 270.4f, 300.5f,
179                                     146.1f, 172.2f, 198.3f, 224.4f, 250.5f,
180                                     112.1f, 134.2f, 156.3f, 178.4f, 200.5f }}});
181 }
182 
183 TEST_CASE_FIXTURE(GemmTransBFixture, "GemmTransposeBTest")
184 {
185     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
186                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
187                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
188                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
189                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
190                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }},
191                        {"C", { 0.10f, 0.20f, 0.30f, 0.40f, 0.50f }}},
192                       {{"Output", { 100.1f, 268.2f, 436.3f, 604.4f, 772.5f,
193                                     60.1f, 164.2f, 268.3f, 372.4f, 476.5f,
194                                     20.1f, 60.2f, 100.3f, 140.4f, 180.5f }}});
195 }
196 
197 TEST_CASE_FIXTURE(GemmParseExceptionFixture, "GemmParseExceptionTest")
198 {
199     // ParseException because Input C is non-constant and has 2 dimension (should be 1 dimension)
200     CHECK_THROWS_AS(Setup(), armnn::ParseException);
201 }
202 
203 struct GemmConstantFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
204 {
GemmConstantFixtureGemmConstantFixture205     GemmConstantFixture()
206     {
207         m_Prototext = R"(
208                     ir_version: 8
209                     producer_name: "onnx-example"
210                     graph {
211                       node {
212                         input: "A"
213                         input: "B"
214                         input: "C"
215                         output: "Output"
216                         op_type: "Gemm"
217                         attribute {
218                           name: "alpha"
219                           f: 0.25
220                           type: FLOAT
221                         }
222                         attribute {
223                           name: "beta"
224                           f: 0.35
225                           type: FLOAT
226                         }
227                         attribute {
228                           name: "transA"
229                           i: 1
230                           type: INT
231                         }
232                         attribute {
233                           name: "transB"
234                           i: 1
235                           type: INT
236                         }
237                       }
238                       name: "gem-model"
239                       initializer {
240                         dims: 5
241                         dims: 4
242                         data_type: 1
243                         float_data: 1.0
244                         float_data: 2.0
245                         float_data: 3.0
246                         float_data: 4.0
247                         float_data: 5.0
248                         float_data: 6.0
249                         float_data: 7.0
250                         float_data: 8.0
251                         float_data: 9.0
252                         float_data: 10.0
253                         float_data: 11.0
254                         float_data: 12.0
255                         float_data: 13.0
256                         float_data: 14.0
257                         float_data: 15.0
258                         float_data: 16.0
259                         float_data: 17.0
260                         float_data: 18.0
261                         float_data: 19.0
262                         float_data: 20.0
263                         name: "B"
264                       }
265                       initializer {
266                         dims: 1
267                         dims: 5
268                         data_type: 1
269                         float_data: 0.1
270                         float_data: 0.2
271                         float_data: 0.3
272                         float_data: 0.4
273                         float_data: 0.5
274                         name: "C"
275                       }
276                       input {
277                         name: "A"
278                         type {
279                           tensor_type {
280                             elem_type: 1
281                             shape {
282                               dim {
283                                 dim_value: 4
284                               }
285                               dim {
286                                 dim_value: 3
287                               }
288                             }
289                           }
290                         }
291                       }
292                       output {
293                         name: "Output"
294                         type {
295                           tensor_type {
296                             elem_type: 1
297                             shape {
298                               dim {
299                                 dim_value: 3
300                               }
301                               dim {
302                                 dim_value: 5
303                               }
304                             }
305                           }
306                         }
307                       }
308                     })";
309         Setup();
310     }
311 };
312 
313 TEST_CASE_FIXTURE(GemmConstantFixture, "GemmConstantTest")
314 {
315     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
316                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
317                       {{"Output", { 15.035f, 45.07f, 75.105f, 105.14f, 135.175f,
318                                     12.535f, 38.57f, 64.605f, 90.64f, 116.675f,
319                                     10.035f, 32.07f,  54.105f, 76.14f, 98.175f }}});
320 }
321 
322 struct GemmConstantSimpleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
323 {
GemmConstantSimpleFixtureGemmConstantSimpleFixture324     GemmConstantSimpleFixture()
325     {
326         m_Prototext = R"(
327                     ir_version: 8
328                     producer_name: "onnx-example"
329                     graph {
330                       node {
331                         input: "A"
332                         input: "B"
333                         input: "C"
334                         output: "Output"
335                         op_type: "Gemm"
336                         attribute {
337                           name: "alpha"
338                           f: 1
339                           type: FLOAT
340                         }
341                         attribute {
342                           name: "beta"
343                           f: 1
344                           type: FLOAT
345                         }
346                         attribute {
347                           name: "transA"
348                           i: 0
349                           type: INT
350                         }
351                         attribute {
352                           name: "transB"
353                           i: 0
354                           type: INT
355                         }
356                       }
357                       name: "gem-model"
358                       initializer {
359                         dims: 4
360                         dims: 5
361                         data_type: 1
362                         float_data: 1.0
363                         float_data: 2.0
364                         float_data: 3.0
365                         float_data: 4.0
366                         float_data: 5.0
367                         float_data: 6.0
368                         float_data: 7.0
369                         float_data: 8.0
370                         float_data: 9.0
371                         float_data: 10.0
372                         float_data: 11.0
373                         float_data: 12.0
374                         float_data: 13.0
375                         float_data: 14.0
376                         float_data: 15.0
377                         float_data: 16.0
378                         float_data: 17.0
379                         float_data: 18.0
380                         float_data: 19.0
381                         float_data: 20.0
382                         name: "B"
383                       }
384                       initializer {
385                         dims: 1
386                         dims: 5
387                         data_type: 1
388                         float_data: 0.1
389                         float_data: 0.2
390                         float_data: 0.3
391                         float_data: 0.4
392                         float_data: 0.5
393                         name: "C"
394                       }
395                       input {
396                         name: "A"
397                         type {
398                           tensor_type {
399                             elem_type: 1
400                             shape {
401                               dim {
402                                 dim_value: 3
403                               }
404                               dim {
405                                 dim_value: 4
406                               }
407                             }
408                           }
409                         }
410                       }
411                       output {
412                         name: "Output"
413                         type {
414                           tensor_type {
415                             elem_type: 1
416                             shape {
417                               dim {
418                                 dim_value: 3
419                               }
420                               dim {
421                                 dim_value: 5
422                               }
423                             }
424                           }
425                         }
426                       }
427                     })";
428         Setup();
429     }
430 };
431 
432 TEST_CASE_FIXTURE(GemmConstantSimpleFixture, "GemmConstantSimpleTest")
433 {
434     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
435                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }}},
436                       {{"Output", { 332.1f, 374.2f, 416.3f, 458.4f, 500.5f,
437                                     196.1f, 222.2f, 248.3f, 274.4f, 300.5f,
438                                     60.1f, 70.2f, 80.3f, 90.4f, 100.5f }}});
439 }
440 
441 struct GemmABFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
442 {
GemmABFixtureGemmABFixture443     GemmABFixture(const std::string& alpha,
444                   const std::string& beta,
445                   const std::string& transA,
446                   const std::string& transB,
447                   const std::vector<int>& inputAShape,
448                   const std::vector<int>& inputBShape,
449                   const std::vector<int>& outputShape)
450     {
451         m_Prototext = R"(
452                     ir_version: 8
453                     producer_name: "onnx-example"
454                     graph {
455                       node {
456                         input: "A"
457                         input: "B"
458                         output: "Output"
459                         op_type: "Gemm"
460                         attribute {
461                           name: "alpha"
462                           f: )" + alpha + R"(
463                           type: FLOAT
464                         }
465                         attribute {
466                           name: "beta"
467                           f: )" + beta + R"(
468                           type: FLOAT
469                         }
470                         attribute {
471                           name: "transA"
472                           i: )" + transA + R"(
473                           type: INT
474                         }
475                         attribute {
476                           name: "transB"
477                           i: )" + transB + R"(
478                           type: INT
479                         }
480                       }
481                       name: "gem-model"
482                       input {
483                         name: "A"
484                         type {
485                           tensor_type {
486                             elem_type: 1
487                             shape {
488                               )" + armnnUtils::ConstructTensorShapeString(inputAShape) + R"(
489                             }
490                           }
491                         }
492                       }
493                       input {
494                         name: "B"
495                         type {
496                           tensor_type {
497                             elem_type: 1
498                             shape {
499                               )" + armnnUtils::ConstructTensorShapeString(inputBShape) + R"(
500                             }
501                           }
502                         }
503                       }
504                       output {
505                         name: "Output"
506                         type {
507                           tensor_type {
508                             elem_type: 1
509                             shape {
510                               )" + armnnUtils::ConstructTensorShapeString(outputShape) + R"(
511                             }
512                           }
513                         }
514                       }
515                     })";
516         Setup();
517     }
518 };
519 
520 struct GemmAlphaTransAFixture : GemmABFixture
521 {
GemmAlphaTransAFixtureGemmAlphaTransAFixture522     GemmAlphaTransAFixture() : GemmABFixture("0.25", "0.35", "1", "0", { 4, 3 }, { 4, 5 }, { 3, 5 }) {}
523 };
524 
525 struct GemmAlphaTransBFixture : GemmABFixture
526 {
GemmAlphaTransBFixtureGemmAlphaTransBFixture527     GemmAlphaTransBFixture() : GemmABFixture("0.25", "0.35", "0", "1", { 3, 4 }, { 5, 4 }, { 3, 5 }) {}
528 };
529 
530 TEST_CASE_FIXTURE(GemmAlphaTransAFixture, "GemmAlphaTransATest")
531 {
532     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
533                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
534                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
535                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
536                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
537                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
538                       {{"Output", { 45.0f, 52.5f, 60.0f, 67.5f, 75.0f,
539                                     36.5f, 43.0f, 49.5f, 56.0f, 62.5f,
540                                     28.0f, 33.5f, 39.0f, 44.5f, 50.0f }}});
541 }
542 
543 TEST_CASE_FIXTURE(GemmAlphaTransBFixture, "GemmAlphaTransBTest")
544 {
545     RunTest<2, float>({{"A", { 12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f,
546                                6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f }},
547                        {"B", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f,
548                                6.0f, 7.0f, 8.0f, 9.0f, 10.0f,
549                                11.0f, 12.0f, 13.0f, 14.0f, 15.0f,
550                                16.0f, 17.0f, 18.0f, 19.0f, 20.0f }}},
551                       {{"Output", { 25.0f, 67.0f, 109.0f, 151.0f, 193.0f,
552                                     15.0f, 41.0f, 67.0f, 93.0f, 119.0f,
553                                     5.0f, 15.0f, 25.0f, 35.0f, 45.0f }}});
554 }
555 
556 }
557