xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/Pooling.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "armnnOnnxParser/IOnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include  "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_Pooling")
10*89c4ff92SAndroid Build Coastguard Worker {
// Shared fixture for the MaxPool / AveragePool tests in this suite.
//
// Builds (in prototext form) an ONNX graph with a single pooling node. The node reads a
// 1x1x2x2 tensor named "Input" and writes a 1x1x1x1 tensor named "Output". It uses
// kernel_shape 2x2, strides 1x1 and all-zero pads, so exactly one window covers the
// whole input. The constructor only fills m_Prototext; derived fixtures decide whether
// (and when) to call Setup().
struct PoolingMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
    /// @param dataType Text pasted verbatim as the *input* tensor's elem_type, e.g. "1"
    ///                 (FLOAT in ONNX TensorProto.DataType). The output tensor's
    ///                 elem_type is always 1, independent of this argument.
    /// @param op       Text pasted verbatim as the node's op_type. Because it is
    ///                 inserted into the prototext as-is, it must include its own
    ///                 quotes, e.g. "\"MaxPool\"".
    PoolingMainFixture(const std::string& dataType, const std::string& op)
    {
        m_Prototext = R"(
                   ir_version: 3
                   producer_name:  "CNTK"
                   producer_version:  "2.5.1"
                   domain:  "ai.cntk"
                   model_version: 1
                   graph {
                     name:  "CNTKGraph"
                     input {
                        name: "Input"
                        type {
                          tensor_type {
                            elem_type: )" + dataType + R"(
                            shape {
                              dim {
                                dim_value: 1
                              }
                              dim {
                                dim_value: 1
                              }
                              dim {
                                dim_value: 2
                              }
                              dim {
                                dim_value: 2
                              }
                            }
                          }
                        }
                      }
                     node {
                         input: "Input"
                         output: "Output"
                         name: "Pooling"
                         op_type: )" + op + R"(
                         attribute {
                           name: "kernel_shape"
                           ints: 2
                           ints: 2
                           type: INTS
                         }
                         attribute {
                           name: "strides"
                           ints: 1
                           ints: 1
                           type: INTS
                         }
                         attribute {
                           name: "pads"
                           ints: 0
                           ints: 0
                           ints: 0
                           ints: 0
                           type: INTS
                         }
                      }
                      output {
                          name: "Output"
                          type {
                             tensor_type {
                               elem_type: 1
                               shape {
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                               }
                            }
                        }
                        }
                    }
                   opset_import {
                      version: 7
                    })";
        // No Setup() here: derived fixtures choose whether parsing happens at
        // construction time or inside the test body.
    }
};
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker struct MaxPoolValidFixture : PoolingMainFixture
101*89c4ff92SAndroid Build Coastguard Worker {
MaxPoolValidFixtureMaxPoolValidFixture102*89c4ff92SAndroid Build Coastguard Worker     MaxPoolValidFixture() : PoolingMainFixture("1", "\"MaxPool\"") {
103*89c4ff92SAndroid Build Coastguard Worker         Setup();
104*89c4ff92SAndroid Build Coastguard Worker     }
105*89c4ff92SAndroid Build Coastguard Worker };
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker struct MaxPoolInvalidFixture : PoolingMainFixture
108*89c4ff92SAndroid Build Coastguard Worker {
MaxPoolInvalidFixtureMaxPoolInvalidFixture109*89c4ff92SAndroid Build Coastguard Worker     MaxPoolInvalidFixture() : PoolingMainFixture("10", "\"MaxPool\"") { }
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MaxPoolValidFixture, "ValidMaxPoolTest")
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {3.0f}}});
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker struct AvgPoolValidFixture : PoolingMainFixture
118*89c4ff92SAndroid Build Coastguard Worker {
AvgPoolValidFixtureAvgPoolValidFixture119*89c4ff92SAndroid Build Coastguard Worker     AvgPoolValidFixture() : PoolingMainFixture("1", "\"AveragePool\"") {
120*89c4ff92SAndroid Build Coastguard Worker         Setup();
121*89c4ff92SAndroid Build Coastguard Worker     }
122*89c4ff92SAndroid Build Coastguard Worker };
123*89c4ff92SAndroid Build Coastguard Worker 
// Fixture for AveragePool with explicit padding.
//
// Builds (in prototext form) an ONNX graph with a single AveragePool node. The node reads
// a 1x1x2x2 tensor "Input" and writes a 1x1x1x1 tensor "Output"; both use elem_type 1,
// i.e. FLOAT. The node's attributes are:
//   * kernel_shape 4x4;
//   * strides 1x1;
//   * pads of 1 on every edge;
//   * count_include_pad = 1.
// The padded input is therefore 4x4 and exactly one window is produced. Per the ONNX
// AveragePool spec, count_include_pad = 1 means the padding zeros count toward the
// divisor. The network is parsed and loaded in the constructor via Setup().
struct PoolingWithPadFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
    PoolingWithPadFixture()
    {
        m_Prototext = R"(
                   ir_version: 3
                   producer_name:  "CNTK"
                   producer_version:  "2.5.1"
                   domain:  "ai.cntk"
                   model_version: 1
                   graph {
                     name:  "CNTKGraph"
                     input {
                        name: "Input"
                        type {
                          tensor_type {
                            elem_type: 1
                            shape {
                              dim {
                                dim_value: 1
                              }
                              dim {
                                dim_value: 1
                              }
                              dim {
                                dim_value: 2
                              }
                              dim {
                                dim_value: 2
                              }
                            }
                          }
                        }
                      }
                     node {
                         input: "Input"
                         output: "Output"
                         name: "Pooling"
                         op_type: "AveragePool"
                         attribute {
                           name: "kernel_shape"
                           ints: 4
                           ints: 4
                           type: INTS
                         }
                         attribute {
                           name: "strides"
                           ints: 1
                           ints: 1
                           type: INTS
                         }
                         attribute {
                           name: "pads"
                           ints: 1
                           ints: 1
                           ints: 1
                           ints: 1
                           type: INTS
                         }
                         attribute {
                           name: "count_include_pad"
                           i: 1
                           type: INT
                         }
                      }
                      output {
                          name: "Output"
                          type {
                             tensor_type {
                               elem_type: 1
                               shape {
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                               }
                            }
                        }
                        }
                    }
                   opset_import {
                      version: 7
                    })";
        // Parse and load the network immediately; tests only need to call RunTest.
        Setup();
    }
};
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(AvgPoolValidFixture, "AveragePoolValid")
220*89c4ff92SAndroid Build Coastguard Worker {
221*89c4ff92SAndroid Build Coastguard Worker     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {0.5}}});
222*89c4ff92SAndroid Build Coastguard Worker }
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(PoolingWithPadFixture, "ValidAvgWithPadTest")
225*89c4ff92SAndroid Build Coastguard Worker {
226*89c4ff92SAndroid Build Coastguard Worker     RunTest<4>({{"Input", {1.0f, 2.0f, 3.0f, -4.0f}}}, {{"Output", {1.0/8.0}}});
227*89c4ff92SAndroid Build Coastguard Worker }
228*89c4ff92SAndroid Build Coastguard Worker 
// Fixture for GlobalAveragePool.
//
// Builds (in prototext form) an ONNX graph with a single GlobalAveragePool node
// (no attributes). The node reads a 1x2x2x2 tensor "Input" and writes a 1x2x1x1
// tensor "Output": the two trailing spatial dims are reduced to 1, leaving one
// value per entry of dim 1 (presumably the channel axis, ONNX NCHW). Both
// tensors use elem_type 1 (FLOAT). The network is parsed and loaded in the
// constructor via Setup().
struct GlobalAvgFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
{
    GlobalAvgFixture()
    {
        m_Prototext = R"(
                   ir_version: 3
                   producer_name:  "CNTK"
                   producer_version:  "2.5.1"
                   domain:  "ai.cntk"
                   model_version: 1
                   graph {
                     name:  "CNTKGraph"
                     input {
                        name: "Input"
                        type {
                          tensor_type {
                            elem_type: 1
                            shape {
                              dim {
                                dim_value: 1
                              }
                              dim {
                                dim_value: 2
                              }
                              dim {
                                dim_value: 2
                              }
                              dim {
                                dim_value: 2
                              }
                            }
                          }
                        }
                      }
                     node {
                         input: "Input"
                         output: "Output"
                         name: "Pooling"
                         op_type: "GlobalAveragePool"
                      }
                      output {
                          name: "Output"
                          type {
                             tensor_type {
                               elem_type: 1
                               shape {
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 2
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                                   dim {
                                       dim_value: 1
                                   }
                               }
                            }
                        }
                        }
                    }
                   opset_import {
                      version: 7
                    })";
        // Parse and load the network immediately; tests only need to call RunTest.
        Setup();
    }
};
298*89c4ff92SAndroid Build Coastguard Worker 
299*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GlobalAvgFixture, "GlobalAvgTest")
300*89c4ff92SAndroid Build Coastguard Worker {
301*89c4ff92SAndroid Build Coastguard Worker     RunTest<4>({{"Input", {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}}}, {{"Output", {10/4.0, 26/4.0}}});
302*89c4ff92SAndroid Build Coastguard Worker }
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MaxPoolInvalidFixture, "IncorrectDataTypeMaxPool")
305*89c4ff92SAndroid Build Coastguard Worker {
306*89c4ff92SAndroid Build Coastguard Worker    CHECK_THROWS_AS(Setup(), armnn::ParseException);
307*89c4ff92SAndroid Build Coastguard Worker }
308*89c4ff92SAndroid Build Coastguard Worker 
309*89c4ff92SAndroid Build Coastguard Worker }
310