xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/AvgPool2D.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_AvgPool2D")
8*89c4ff92SAndroid Build Coastguard Worker {
9*89c4ff92SAndroid Build Coastguard Worker struct AvgPool2DFixture : public ParserFlatbuffersFixture
10*89c4ff92SAndroid Build Coastguard Worker {
AvgPool2DFixtureAvgPool2DFixture11*89c4ff92SAndroid Build Coastguard Worker     explicit AvgPool2DFixture(std::string inputdim, std::string outputdim, std::string dataType)
12*89c4ff92SAndroid Build Coastguard Worker     {
13*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
14*89c4ff92SAndroid Build Coastguard Worker         {
15*89c4ff92SAndroid Build Coastguard Worker             "version": 3,
16*89c4ff92SAndroid Build Coastguard Worker             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
17*89c4ff92SAndroid Build Coastguard Worker             "subgraphs": [
18*89c4ff92SAndroid Build Coastguard Worker             {
19*89c4ff92SAndroid Build Coastguard Worker                 "tensors": [
20*89c4ff92SAndroid Build Coastguard Worker                 {
21*89c4ff92SAndroid Build Coastguard Worker                     "shape": )"
22*89c4ff92SAndroid Build Coastguard Worker                     + outputdim
23*89c4ff92SAndroid Build Coastguard Worker                     + R"(,
24*89c4ff92SAndroid Build Coastguard Worker                     "type": )"
25*89c4ff92SAndroid Build Coastguard Worker                       + dataType
26*89c4ff92SAndroid Build Coastguard Worker                       + R"(,
27*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
28*89c4ff92SAndroid Build Coastguard Worker                             "name": "OutputTensor",
29*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
30*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
31*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
33*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ]
34*89c4ff92SAndroid Build Coastguard Worker                             }
35*89c4ff92SAndroid Build Coastguard Worker                 },
36*89c4ff92SAndroid Build Coastguard Worker                 {
37*89c4ff92SAndroid Build Coastguard Worker                     "shape": )"
38*89c4ff92SAndroid Build Coastguard Worker                     + inputdim
39*89c4ff92SAndroid Build Coastguard Worker                     + R"(,
40*89c4ff92SAndroid Build Coastguard Worker                     "type": )"
41*89c4ff92SAndroid Build Coastguard Worker                       + dataType
42*89c4ff92SAndroid Build Coastguard Worker                       + R"(,
43*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
44*89c4ff92SAndroid Build Coastguard Worker                             "name": "InputTensor",
45*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
46*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
47*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
48*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
49*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ]
50*89c4ff92SAndroid Build Coastguard Worker                             }
51*89c4ff92SAndroid Build Coastguard Worker                 }
52*89c4ff92SAndroid Build Coastguard Worker                 ],
53*89c4ff92SAndroid Build Coastguard Worker                 "inputs": [ 1 ],
54*89c4ff92SAndroid Build Coastguard Worker                 "outputs": [ 0 ],
55*89c4ff92SAndroid Build Coastguard Worker                 "operators": [ {
56*89c4ff92SAndroid Build Coastguard Worker                         "opcode_index": 0,
57*89c4ff92SAndroid Build Coastguard Worker                         "inputs": [ 1 ],
58*89c4ff92SAndroid Build Coastguard Worker                         "outputs": [ 0 ],
59*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options_type": "Pool2DOptions",
60*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options":
61*89c4ff92SAndroid Build Coastguard Worker                         {
62*89c4ff92SAndroid Build Coastguard Worker                             "padding": "VALID",
63*89c4ff92SAndroid Build Coastguard Worker                             "stride_w": 2,
64*89c4ff92SAndroid Build Coastguard Worker                             "stride_h": 2,
65*89c4ff92SAndroid Build Coastguard Worker                             "filter_width": 2,
66*89c4ff92SAndroid Build Coastguard Worker                             "filter_height": 2,
67*89c4ff92SAndroid Build Coastguard Worker                             "fused_activation_function": "NONE"
68*89c4ff92SAndroid Build Coastguard Worker                         },
69*89c4ff92SAndroid Build Coastguard Worker                         "custom_options_format": "FLEXBUFFERS"
70*89c4ff92SAndroid Build Coastguard Worker                     } ]
71*89c4ff92SAndroid Build Coastguard Worker                 }
72*89c4ff92SAndroid Build Coastguard Worker             ],
73*89c4ff92SAndroid Build Coastguard Worker             "description": "AvgPool2D test.",
74*89c4ff92SAndroid Build Coastguard Worker             "buffers" : [ {}, {} ]
75*89c4ff92SAndroid Build Coastguard Worker         })";
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker         SetupSingleInputSingleOutput("InputTensor", "OutputTensor");
78*89c4ff92SAndroid Build Coastguard Worker     }
79*89c4ff92SAndroid Build Coastguard Worker };
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker struct AvgPoolLiteFixtureUint1DOutput : AvgPool2DFixture
83*89c4ff92SAndroid Build Coastguard Worker {
AvgPoolLiteFixtureUint1DOutputAvgPoolLiteFixtureUint1DOutput84*89c4ff92SAndroid Build Coastguard Worker     AvgPoolLiteFixtureUint1DOutput() : AvgPool2DFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]", "UINT8") {}
85*89c4ff92SAndroid Build Coastguard Worker };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker struct AvgPoolLiteFixtureFloat1DOutput : AvgPool2DFixture
88*89c4ff92SAndroid Build Coastguard Worker {
AvgPoolLiteFixtureFloat1DOutputAvgPoolLiteFixtureFloat1DOutput89*89c4ff92SAndroid Build Coastguard Worker     AvgPoolLiteFixtureFloat1DOutput() : AvgPool2DFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]", "FLOAT32") {}
90*89c4ff92SAndroid Build Coastguard Worker };
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker struct AvgPoolLiteFixture2DOutput : AvgPool2DFixture
93*89c4ff92SAndroid Build Coastguard Worker {
AvgPoolLiteFixture2DOutputAvgPoolLiteFixture2DOutput94*89c4ff92SAndroid Build Coastguard Worker     AvgPoolLiteFixture2DOutput() : AvgPool2DFixture("[ 1, 4, 4, 1 ]", "[ 1, 2, 2, 1 ]", "UINT8") {}
95*89c4ff92SAndroid Build Coastguard Worker };
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(AvgPoolLiteFixtureUint1DOutput, "AvgPoolLite1DOutput")
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(0, {2, 3, 5, 2 }, { 3 });
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(AvgPoolLiteFixtureFloat1DOutput, "AvgPoolLiteFloat1DOutput")
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::Float32>(0, { 2.0f, 3.0f, 5.0f, 2.0f },  { 3.0f });
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(AvgPoolLiteFixture2DOutput, "AvgPoolLite2DOutput")
108*89c4ff92SAndroid Build Coastguard Worker {
109*89c4ff92SAndroid Build Coastguard Worker     RunTest<4, armnn::DataType::QAsymmU8>(
110*89c4ff92SAndroid Build Coastguard Worker         0, { 1, 2, 2, 3, 5, 6, 7, 8, 3, 2, 1, 0, 1, 2, 3, 4 }, { 4, 5, 2, 2 });
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(AvgPoolLiteFixtureFloat1DOutput, "IncorrectDataTypeError")
114*89c4ff92SAndroid Build Coastguard Worker {
115*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS((RunTest<4, armnn::DataType::QAsymmU8>(0, {2, 3, 5, 2 }, { 3 })), armnn::Exception);
116*89c4ff92SAndroid Build Coastguard Worker }
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker }
119