xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Conv2D.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnOnnxParser/IOnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 
9 TEST_SUITE("OnnxParser_Conv2D")
10 {
11 struct SimpleConv2DFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
12 {
SimpleConv2DFixtureSimpleConv2DFixture13     SimpleConv2DFixture()
14     {
15         m_Prototext = R"(
16                    ir_version: 3
17                    producer_name:  "CNTK"
18                    producer_version:  "2.5.1"
19                    domain:  "ai.cntk"
20                    model_version: 1
21                    graph {
22                      name:  "CNTKGraph"
23                      input {
24                         name: "Input"
25                         type {
26                           tensor_type {
27                             elem_type: 1
28                             shape {
29                               dim {
30                                 dim_value: 1
31                               }
32                               dim {
33                                 dim_value: 1
34                               }
35                               dim {
36                                 dim_value: 3
37                               }
38                               dim {
39                                 dim_value: 3
40                               }
41                             }
42                           }
43                         }
44                       }
45                       input {
46                         name: "Weight"
47                         type {
48                           tensor_type {
49                             elem_type: 1
50                             shape {
51                               dim {
52                                 dim_value: 1
53                               }
54                               dim {
55                                 dim_value: 1
56                               }
57                               dim {
58                                 dim_value: 3
59                               }
60                               dim {
61                                 dim_value: 3
62                               }
63                             }
64                           }
65                         }
66                       }
67                       initializer {
68                           dims: 1
69                           dims: 1
70                           dims: 3
71                           dims: 3
72                           data_type: 1
73                           float_data: 2
74                           float_data: 1
75                           float_data: 0
76                           float_data: 6
77                           float_data: 2
78                           float_data: 1
79                           float_data: 4
80                           float_data: 1
81                           float_data: 2
82                           name: "Weight"
83                         }
84                       node {
85                          input: "Input"
86                          input: "Weight"
87                          output: "Output"
88                          name: "Convolution"
89                          op_type: "Conv"
90                          attribute {
91                            name: "kernel_shape"
92                            ints: 3
93                            ints: 3
94                            type: INTS
95                          }
96                          attribute {
97                            name: "strides"
98                            ints: 1
99                            ints: 1
100                            type: INTS
101                          }
102                          attribute {
103                            name: "auto_pad"
104                            s: "VALID"
105                            type: STRING
106                          }
107                          attribute {
108                            name: "group"
109                            i: 1
110                            type: INT
111                          }
112                          attribute {
113                            name: "dilations"
114                            ints: 1
115                            ints: 1
116                            type: INTS
117                          }
118                          doc_string: ""
119                          domain: ""
120                        }
121                       output {
122                           name: "Output"
123                           type {
124                              tensor_type {
125                                elem_type: 1
126                                shape {
127                                    dim {
128                                        dim_value: 1
129                                    }
130                                    dim {
131                                        dim_value: 1
132                                    }
133                                    dim {
134                                        dim_value: 1
135                                    }
136                                    dim {
137                                        dim_value: 1
138                                    }
139                                }
140                             }
141                         }
142                         }
143                     }
144                    opset_import {
145                       version: 7
146                     })";
147         Setup();
148     }
149 };
150 
151 struct Conv2DWithBiasesFixture :  public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
152 {
Conv2DWithBiasesFixtureConv2DWithBiasesFixture153     Conv2DWithBiasesFixture() {
154         m_Prototext = R"(
155                    ir_version: 3
156                    producer_name:  "CNTK"
157                    producer_version:  "2.5.1"
158                    domain:  "ai.cntk"
159                    model_version: 1
160                    graph {
161                      name:  "CNTKGraph"
162                      input {
163                         name: "Input"
164                         type {
165                           tensor_type {
166                             elem_type: 1
167                             shape {
168                               dim {
169                                 dim_value: 1
170                               }
171                               dim {
172                                 dim_value: 1
173                               }
174                               dim {
175                                 dim_value: 2
176                               }
177                               dim {
178                                 dim_value: 2
179                               }
180                             }
181                           }
182                         }
183                       }
184                       input {
185                         name: "Weight"
186                         type {
187                           tensor_type {
188                             elem_type: 1
189                             shape {
190                               dim {
191                                 dim_value: 1
192                               }
193                               dim {
194                                 dim_value: 1
195                               }
196                               dim {
197                                 dim_value: 2
198                               }
199                               dim {
200                                 dim_value: 2
201                               }
202                             }
203                           }
204                         }
205                       }
206                       initializer {
207                           dims: 1
208                           dims: 1
209                           dims: 2
210                           dims: 2
211                           data_type: 1
212                           float_data: 2
213                           float_data: 1
214                           float_data: 0
215                           float_data: 6
216                           name: "Weight"
217                         }
218                         input {
219                           name: "Bias"
220                           type {
221                             tensor_type {
222                               elem_type: 1
223                               shape {
224                                 dim {
225                                   dim_value: 4
226                                 }
227                               }
228                             }
229                           }
230                         }
231                         initializer {
232                             dims: 4
233                             data_type: 1
234                             float_data: 10
235                             float_data: 0
236                             float_data: 0
237                             float_data: 0
238                             name: "Bias"
239                           }
240                       node {
241                          input: "Input"
242                          input: "Weight"
243                          input: "Bias"
244                          output: "Output"
245                          name: "Convolution"
246                          op_type: "Conv"
247                          attribute {
248                            name: "kernel_shape"
249                            ints: 2
250                            ints: 2
251                            type: INTS
252                          }
253                          attribute {
254                            name: "strides"
255                            ints: 1
256                            ints: 1
257                            type: INTS
258                          }
259                          attribute {
260                            name: "auto_pad"
261                            s: "SAME_UPPER"
262                            type: STRING
263                          }
264                          attribute {
265                            name: "group"
266                            i: 1
267                            type: INT
268                          }
269                          attribute {
270                            name: "dilations"
271                            ints: 1
272                            ints: 1
273                            type: INTS
274                          }
275                          doc_string: ""
276                          domain: ""
277                        }
278                       output {
279                           name: "Output"
280                           type {
281                              tensor_type {
282                                elem_type: 1
283                                shape {
284                                    dim {
285                                        dim_value: 1
286                                    }
287                                    dim {
288                                        dim_value: 1
289                                    }
290                                    dim {
291                                        dim_value: 2
292                                    }
293                                    dim {
294                                        dim_value: 2
295                                    }
296                                }
297                             }
298                         }
299                         }
300                     }
301                    opset_import {
302                       version: 7
303                     })";
304         Setup();
305     }
306 };
307 
308 
309 struct Conv2DDimReducingFixture :  public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
310 {
Conv2DDimReducingFixtureConv2DDimReducingFixture311     Conv2DDimReducingFixture() {
312         m_Prototext = R"(
313                    ir_version: 3
314                    producer_name:  "CNTK"
315                    producer_version:  "2.5.1"
316                    domain:  "ai.cntk"
317                    model_version: 1
318                    graph {
319                      name:  "CNTKGraph"
320                      input {
321                         name: "Input"
322                         type {
323                           tensor_type {
324                             elem_type: 1
325                             shape {
326                               dim {
327                                 dim_value: 1
328                               }
329                               dim {
330                                 dim_value: 3
331                               }
332                               dim {
333                                 dim_value: 2
334                               }
335                               dim {
336                                 dim_value: 2
337                               }
338                             }
339                           }
340                         }
341                       }
342                       input {
343                         name: "Weight"
344                         type {
345                           tensor_type {
346                             elem_type: 1
347                             shape {
348                               dim {
349                                 dim_value: 2
350                               }
351                               dim {
352                                 dim_value: 3
353                               }
354                               dim {
355                                 dim_value: 1
356                               }
357                               dim {
358                                 dim_value: 1
359                               }
360                             }
361                           }
362                         }
363                       }
364                       initializer {
365                           dims: 2
366                           dims: 3
367                           dims: 1
368                           dims: 1
369                           data_type: 1
370                           float_data: -1
371                           float_data: 2
372                           float_data: 0
373                           float_data: 1
374                           float_data: 0
375                           float_data: 0
376                           name: "Weight"
377                         }
378                       node {
379                          input: "Input"
380                          input: "Weight"
381                          output: "Output"
382                          name: "Convolution"
383                          op_type: "Conv"
384                          attribute {
385                            name: "kernel_shape"
386                            ints: 1
387                            ints: 1
388                            type: INTS
389                          }
390                          attribute {
391                            name: "strides"
392                            ints: 1
393                            ints: 1
394                            type: INTS
395                          }
396                          attribute {
397                            name: "group"
398                            i: 1
399                            type: INT
400                          }
401                          attribute {
402                            name: "dilations"
403                            ints: 1
404                            ints: 1
405                            type: INTS
406                          }
407                          doc_string: ""
408                          domain: ""
409                        }
410                       output {
411                           name: "Output"
412                           type {
413                              tensor_type {
414                                elem_type: 1
415                                shape {
416                                    dim {
417                                        dim_value: 1
418                                    }
419                                    dim {
420                                        dim_value: 2
421                                    }
422                                    dim {
423                                        dim_value: 2
424                                    }
425                                    dim {
426                                        dim_value: 2
427                                    }
428                                }
429                             }
430                         }
431                         }
432                     }
433                    opset_import {
434                       version: 7
435                     })";
436         Setup();
437     }
438 };
439 
440 struct Conv2DwithDilationFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
441 {
Conv2DwithDilationFixtureConv2DwithDilationFixture442     Conv2DwithDilationFixture()
443     {
444         m_Prototext = R"(
445                    ir_version: 3
446                    producer_name:  "CNTK"
447                    producer_version:  "2.5.1"
448                    domain:  "ai.cntk"
449                    model_version: 1
450                    graph {
451                      name:  "CNTKGraph"
452                      input {
453                         name: "Input"
454                         type {
455                           tensor_type {
456                             elem_type: 1
457                             shape {
458                               dim {
459                                 dim_value: 1
460                               }
461                               dim {
462                                 dim_value: 1
463                               }
464                               dim {
465                                 dim_value: 6
466                               }
467                               dim {
468                                 dim_value: 6
469                               }
470                             }
471                           }
472                         }
473                       }
474                       input {
475                         name: "Weight"
476                         type {
477                           tensor_type {
478                             elem_type: 1
479                             shape {
480                               dim {
481                                 dim_value: 1
482                               }
483                               dim {
484                                 dim_value: 1
485                               }
486                               dim {
487                                 dim_value: 3
488                               }
489                               dim {
490                                 dim_value: 3
491                               }
492                             }
493                           }
494                         }
495                       }
496                       initializer {
497                           dims: 1
498                           dims: 1
499                           dims: 3
500                           dims: 3
501                           data_type: 1
502                           float_data: 2
503                           float_data: 1
504                           float_data: 0
505                           float_data: 6
506                           float_data: 2
507                           float_data: 1
508                           float_data: 4
509                           float_data: 1
510                           float_data: 2
511                           name: "Weight"
512                         }
513                       node {
514                          input: "Input"
515                          input: "Weight"
516                          output: "Output"
517                          name: "Convolution"
518                          op_type: "Conv"
519                          attribute {
520                            name: "kernel_shape"
521                            ints: 3
522                            ints: 3
523                            type: INTS
524                          }
525                          attribute {
526                            name: "strides"
527                            ints: 1
528                            ints: 1
529                            type: INTS
530                          }
531                          attribute {
532                            name: "auto_pad"
533                            s: "VALID"
534                            type: STRING
535                          }
536                          attribute {
537                            name: "group"
538                            i: 1
539                            type: INT
540                          }
541                          attribute {
542                            name: "dilations"
543                            ints: 2
544                            ints: 2
545                            type: INTS
546                          }
547                          doc_string: ""
548                          domain: ""
549                        }
550                       output {
551                           name: "Output"
552                           type {
553                              tensor_type {
554                                elem_type: 1
555                                shape {
556                                    dim {
557                                        dim_value: 1
558                                    }
559                                    dim {
560                                        dim_value: 1
561                                    }
562                                    dim {
563                                        dim_value: 2
564                                    }
565                                    dim {
566                                        dim_value: 2
567                                    }
568                                }
569                             }
570                         }
571                         }
572                     }
573                    opset_import {
574                       version: 7
575                     })";
576         Setup();
577     }
578 };
579 
580 TEST_CASE_FIXTURE(SimpleConv2DFixture, "ValidConvTest")
581 {
582     RunTest<4>({{"Input", {1.0, 2.0, 3.0,
583                            4.0, 5.0, 6.0,
584                            7.0, 8.0, 9.0}}},
585               {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 +
586                            4.0 * 6 + 5.0 * 2 + 6.0 * 1 +
587                            7.0 * 4 + 8.0 * 1 + 9.0 * 2}}});
588 }
589 
590 TEST_CASE_FIXTURE(Conv2DWithBiasesFixture, "ValidConvWithBiasTest")
591 {
592     RunTest<4>({{"Input", {1.0, 2.0,
593                            3.0, 4.0}}},
594               {{"Output", {1.0 * 2 + 2.0 * 1 + 3.0 * 0 + 4 * 6 + 10,
595                            2.0 * 2 + 0 * 1 + 4.0 * 0 + 0 * 6 + 10,
596                            3.0 * 2 + 4.0 * 1 + 0 * 0 + 0 * 6 + 10,
597                            4.0 * 2 + 0 * 1 + 0 * 0 + 0 * 6 + 10}}});
598 }
599 
600 TEST_CASE_FIXTURE(Conv2DDimReducingFixture, "ValidConvDimReducTest")
601 {
602     RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, -1, -2, 3, 4, 1 , 1, 1, 1 }}},
603               {{"Output", {-1 * 1 + 2 * -1, -1 * 2 + 2 * -2,
604                            -1 * 3 + 2 * 3,  -1 * 4 + 2 * 4,
605                            1, 2, 3, 4}}});
606 }
607 
608 TEST_CASE_FIXTURE(Conv2DwithDilationFixture, "ValidConvWithDilationTest")
609 {
610     RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
611                            7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
612                            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
613                            7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
614                            1.0, 2.0, 3.0, 4.0, 5.0, 6.0,
615                            7.0, 8.0, 9.0, 10.0, 11.0, 12.0}}},
616                {{"Output", {39.0, 58.0, 153.0, 172.0 }}});
617 }
618 
619 }
620