xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializePooling2d.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Pooling2d")
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker struct Pooling2dFixture : public ParserFlatbuffersSerializeFixture
14*89c4ff92SAndroid Build Coastguard Worker {
Pooling2dFixturePooling2dFixture15*89c4ff92SAndroid Build Coastguard Worker     explicit Pooling2dFixture(const std::string &inputShape,
16*89c4ff92SAndroid Build Coastguard Worker                               const std::string &outputShape,
17*89c4ff92SAndroid Build Coastguard Worker                               const std::string &dataType,
18*89c4ff92SAndroid Build Coastguard Worker                               const std::string &dataLayout,
19*89c4ff92SAndroid Build Coastguard Worker                               const std::string &poolingAlgorithm)
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker             inputIds: [0],
24*89c4ff92SAndroid Build Coastguard Worker             outputIds: [2],
25*89c4ff92SAndroid Build Coastguard Worker             layers: [
26*89c4ff92SAndroid Build Coastguard Worker             {
27*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "InputLayer",
28*89c4ff92SAndroid Build Coastguard Worker                 layer: {
29*89c4ff92SAndroid Build Coastguard Worker                       base: {
30*89c4ff92SAndroid Build Coastguard Worker                             layerBindingId: 0,
31*89c4ff92SAndroid Build Coastguard Worker                             base: {
32*89c4ff92SAndroid Build Coastguard Worker                                 index: 0,
33*89c4ff92SAndroid Build Coastguard Worker                                 layerName: "InputLayer",
34*89c4ff92SAndroid Build Coastguard Worker                                 layerType: "Input",
35*89c4ff92SAndroid Build Coastguard Worker                                 inputSlots: [{
36*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
37*89c4ff92SAndroid Build Coastguard Worker                                     connection: {sourceLayerIndex:0, outputSlotIndex:0 },
38*89c4ff92SAndroid Build Coastguard Worker                                 }],
39*89c4ff92SAndroid Build Coastguard Worker                                 outputSlots: [ {
40*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
41*89c4ff92SAndroid Build Coastguard Worker                                     tensorInfo: {
42*89c4ff92SAndroid Build Coastguard Worker                                         dimensions: )" + inputShape + R"(,
43*89c4ff92SAndroid Build Coastguard Worker                                         dataType: )" + dataType + R"(
44*89c4ff92SAndroid Build Coastguard Worker                                         }}]
45*89c4ff92SAndroid Build Coastguard Worker                                 }
46*89c4ff92SAndroid Build Coastguard Worker                 }}},
47*89c4ff92SAndroid Build Coastguard Worker                 {
48*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "Pooling2dLayer",
49*89c4ff92SAndroid Build Coastguard Worker                 layer: {
50*89c4ff92SAndroid Build Coastguard Worker                       base: {
51*89c4ff92SAndroid Build Coastguard Worker                            index: 1,
52*89c4ff92SAndroid Build Coastguard Worker                            layerName: "Pooling2dLayer",
53*89c4ff92SAndroid Build Coastguard Worker                            layerType: "Pooling2d",
54*89c4ff92SAndroid Build Coastguard Worker                            inputSlots: [{
55*89c4ff92SAndroid Build Coastguard Worker                                   index: 0,
56*89c4ff92SAndroid Build Coastguard Worker                                   connection: {sourceLayerIndex:0, outputSlotIndex:0 },
57*89c4ff92SAndroid Build Coastguard Worker                            }],
58*89c4ff92SAndroid Build Coastguard Worker                            outputSlots: [ {
59*89c4ff92SAndroid Build Coastguard Worker                                   index: 0,
60*89c4ff92SAndroid Build Coastguard Worker                                   tensorInfo: {
61*89c4ff92SAndroid Build Coastguard Worker                                        dimensions: )" + outputShape + R"(,
62*89c4ff92SAndroid Build Coastguard Worker                                        dataType: )" + dataType + R"(
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker                            }}]},
65*89c4ff92SAndroid Build Coastguard Worker                       descriptor: {
66*89c4ff92SAndroid Build Coastguard Worker                            poolType: )" + poolingAlgorithm + R"(,
67*89c4ff92SAndroid Build Coastguard Worker                            outputShapeRounding: "Floor",
68*89c4ff92SAndroid Build Coastguard Worker                            paddingMethod: Exclude,
69*89c4ff92SAndroid Build Coastguard Worker                            dataLayout: )" + dataLayout + R"(,
70*89c4ff92SAndroid Build Coastguard Worker                            padLeft: 0,
71*89c4ff92SAndroid Build Coastguard Worker                            padRight: 0,
72*89c4ff92SAndroid Build Coastguard Worker                            padTop: 0,
73*89c4ff92SAndroid Build Coastguard Worker                            padBottom: 0,
74*89c4ff92SAndroid Build Coastguard Worker                            poolWidth: 2,
75*89c4ff92SAndroid Build Coastguard Worker                            poolHeight: 2,
76*89c4ff92SAndroid Build Coastguard Worker                            strideX: 2,
77*89c4ff92SAndroid Build Coastguard Worker                            strideY: 2
78*89c4ff92SAndroid Build Coastguard Worker                            }
79*89c4ff92SAndroid Build Coastguard Worker                 }},
80*89c4ff92SAndroid Build Coastguard Worker                 {
81*89c4ff92SAndroid Build Coastguard Worker                 layer_type: "OutputLayer",
82*89c4ff92SAndroid Build Coastguard Worker                 layer: {
83*89c4ff92SAndroid Build Coastguard Worker                     base:{
84*89c4ff92SAndroid Build Coastguard Worker                           layerBindingId: 0,
85*89c4ff92SAndroid Build Coastguard Worker                           base: {
86*89c4ff92SAndroid Build Coastguard Worker                                 index: 2,
87*89c4ff92SAndroid Build Coastguard Worker                                 layerName: "OutputLayer",
88*89c4ff92SAndroid Build Coastguard Worker                                 layerType: "Output",
89*89c4ff92SAndroid Build Coastguard Worker                                 inputSlots: [{
90*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
91*89c4ff92SAndroid Build Coastguard Worker                                     connection: {sourceLayerIndex:1, outputSlotIndex:0 },
92*89c4ff92SAndroid Build Coastguard Worker                                 }],
93*89c4ff92SAndroid Build Coastguard Worker                                 outputSlots: [ {
94*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
95*89c4ff92SAndroid Build Coastguard Worker                                     tensorInfo: {
96*89c4ff92SAndroid Build Coastguard Worker                                         dimensions: )" + outputShape + R"(,
97*89c4ff92SAndroid Build Coastguard Worker                                         dataType: )" + dataType + R"(
98*89c4ff92SAndroid Build Coastguard Worker                                     },
99*89c4ff92SAndroid Build Coastguard Worker                             }],
100*89c4ff92SAndroid Build Coastguard Worker                         }}},
101*89c4ff92SAndroid Build Coastguard Worker             }]
102*89c4ff92SAndroid Build Coastguard Worker      }
103*89c4ff92SAndroid Build Coastguard Worker  )";
104*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("InputLayer", "OutputLayer");
105*89c4ff92SAndroid Build Coastguard Worker     }
106*89c4ff92SAndroid Build Coastguard Worker };
107*89c4ff92SAndroid Build Coastguard Worker 
108*89c4ff92SAndroid Build Coastguard Worker struct SimpleAvgPooling2dFixture : Pooling2dFixture
109*89c4ff92SAndroid Build Coastguard Worker {
SimpleAvgPooling2dFixtureSimpleAvgPooling2dFixture110*89c4ff92SAndroid Build Coastguard Worker     SimpleAvgPooling2dFixture() : Pooling2dFixture("[ 1, 2, 2, 1 ]",
111*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 1, 1, 1 ]",
112*89c4ff92SAndroid Build Coastguard Worker                                                    "Float32", "NHWC", "Average") {}
113*89c4ff92SAndroid Build Coastguard Worker };
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker struct SimpleAvgPooling2dFixture2 : Pooling2dFixture
116*89c4ff92SAndroid Build Coastguard Worker {
SimpleAvgPooling2dFixture2SimpleAvgPooling2dFixture2117*89c4ff92SAndroid Build Coastguard Worker     SimpleAvgPooling2dFixture2() : Pooling2dFixture("[ 1, 2, 2, 1 ]",
118*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1 ]",
119*89c4ff92SAndroid Build Coastguard Worker                                                     "QuantisedAsymm8", "NHWC", "Average") {}
120*89c4ff92SAndroid Build Coastguard Worker };
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker struct SimpleMaxPooling2dFixture : Pooling2dFixture
123*89c4ff92SAndroid Build Coastguard Worker {
SimpleMaxPooling2dFixtureSimpleMaxPooling2dFixture124*89c4ff92SAndroid Build Coastguard Worker     SimpleMaxPooling2dFixture() : Pooling2dFixture("[ 1, 1, 2, 2 ]",
125*89c4ff92SAndroid Build Coastguard Worker                                                    "[ 1, 1, 1, 1 ]",
126*89c4ff92SAndroid Build Coastguard Worker                                                    "Float32", "NCHW", "Max") {}
127*89c4ff92SAndroid Build Coastguard Worker };
128*89c4ff92SAndroid Build Coastguard Worker 
129*89c4ff92SAndroid Build Coastguard Worker struct SimpleMaxPooling2dFixture2 : Pooling2dFixture
130*89c4ff92SAndroid Build Coastguard Worker {
SimpleMaxPooling2dFixture2SimpleMaxPooling2dFixture2131*89c4ff92SAndroid Build Coastguard Worker     SimpleMaxPooling2dFixture2() : Pooling2dFixture("[ 1, 1, 2, 2 ]",
132*89c4ff92SAndroid Build Coastguard Worker                                                     "[ 1, 1, 1, 1 ]",
133*89c4ff92SAndroid Build Coastguard Worker                                                     "QuantisedAsymm8", "NCHW", "Max") {}
134*89c4ff92SAndroid Build Coastguard Worker };
135*89c4ff92SAndroid Build Coastguard Worker 
136*89c4ff92SAndroid Build Coastguard Worker struct SimpleL2Pooling2dFixture : Pooling2dFixture
137*89c4ff92SAndroid Build Coastguard Worker {
SimpleL2Pooling2dFixtureSimpleL2Pooling2dFixture138*89c4ff92SAndroid Build Coastguard Worker     SimpleL2Pooling2dFixture() : Pooling2dFixture("[ 1, 2, 2, 1 ]",
139*89c4ff92SAndroid Build Coastguard Worker                                                   "[ 1, 1, 1, 1 ]",
140*89c4ff92SAndroid Build Coastguard Worker                                                   "Float32", "NHWC", "L2") {}
141*89c4ff92SAndroid Build Coastguard Worker };
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleAvgPooling2dFixture, "Pooling2dFloat32Avg")
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0, { 2, 3, 5, 2 }, { 3 });
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleAvgPooling2dFixture2, "Pooling2dQuantisedAsymm8Avg")
149*89c4ff92SAndroid Build Coastguard Worker {
150*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80 },{ 50 });
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMaxPooling2dFixture, "Pooling2dFloat32Max")
154*89c4ff92SAndroid Build Coastguard Worker {
155*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0, { 2, 5, 5, 2 }, { 5 });
156*89c4ff92SAndroid Build Coastguard Worker }
157*89c4ff92SAndroid Build Coastguard Worker 
158*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMaxPooling2dFixture2, "Pooling2dQuantisedAsymm8Max")
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(0,{ 20, 40, 60, 80 },{ 80 });
161*89c4ff92SAndroid Build Coastguard Worker }
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleL2Pooling2dFixture, "Pooling2dFloat32L2")
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0, { 2, 3, 5, 2 }, { 3.2403703f });
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
170