xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/AvgPool2D.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "ParserFlatbuffersFixture.hpp"
6 
7 TEST_SUITE("TensorflowLiteParser_AvgPool2D")
8 {
9 struct AvgPool2DFixture : public ParserFlatbuffersFixture
10 {
AvgPool2DFixtureAvgPool2DFixture11     explicit AvgPool2DFixture(std::string inputdim, std::string outputdim, std::string dataType)
12     {
13         m_JsonString = R"(
14         {
15             "version": 3,
16             "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
17             "subgraphs": [
18             {
19                 "tensors": [
20                 {
21                     "shape": )"
22                     + outputdim
23                     + R"(,
24                     "type": )"
25                       + dataType
26                       + R"(,
27                             "buffer": 0,
28                             "name": "OutputTensor",
29                             "quantization": {
30                                 "min": [ 0.0 ],
31                                 "max": [ 255.0 ],
32                                 "scale": [ 1.0 ],
33                                 "zero_point": [ 0 ]
34                             }
35                 },
36                 {
37                     "shape": )"
38                     + inputdim
39                     + R"(,
40                     "type": )"
41                       + dataType
42                       + R"(,
43                             "buffer": 1,
44                             "name": "InputTensor",
45                             "quantization": {
46                                 "min": [ 0.0 ],
47                                 "max": [ 255.0 ],
48                                 "scale": [ 1.0 ],
49                                 "zero_point": [ 0 ]
50                             }
51                 }
52                 ],
53                 "inputs": [ 1 ],
54                 "outputs": [ 0 ],
55                 "operators": [ {
56                         "opcode_index": 0,
57                         "inputs": [ 1 ],
58                         "outputs": [ 0 ],
59                         "builtin_options_type": "Pool2DOptions",
60                         "builtin_options":
61                         {
62                             "padding": "VALID",
63                             "stride_w": 2,
64                             "stride_h": 2,
65                             "filter_width": 2,
66                             "filter_height": 2,
67                             "fused_activation_function": "NONE"
68                         },
69                         "custom_options_format": "FLEXBUFFERS"
70                     } ]
71                 }
72             ],
73             "description": "AvgPool2D test.",
74             "buffers" : [ {}, {} ]
75         })";
76 
77         SetupSingleInputSingleOutput("InputTensor", "OutputTensor");
78     }
79 };
80 
81 
82 struct AvgPoolLiteFixtureUint1DOutput : AvgPool2DFixture
83 {
AvgPoolLiteFixtureUint1DOutputAvgPoolLiteFixtureUint1DOutput84     AvgPoolLiteFixtureUint1DOutput() : AvgPool2DFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]", "UINT8") {}
85 };
86 
87 struct AvgPoolLiteFixtureFloat1DOutput : AvgPool2DFixture
88 {
AvgPoolLiteFixtureFloat1DOutputAvgPoolLiteFixtureFloat1DOutput89     AvgPoolLiteFixtureFloat1DOutput() : AvgPool2DFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]", "FLOAT32") {}
90 };
91 
92 struct AvgPoolLiteFixture2DOutput : AvgPool2DFixture
93 {
AvgPoolLiteFixture2DOutputAvgPoolLiteFixture2DOutput94     AvgPoolLiteFixture2DOutput() : AvgPool2DFixture("[ 1, 4, 4, 1 ]", "[ 1, 2, 2, 1 ]", "UINT8") {}
95 };
96 
97 TEST_CASE_FIXTURE(AvgPoolLiteFixtureUint1DOutput, "AvgPoolLite1DOutput")
98 {
99     RunTest<4, armnn::DataType::QAsymmU8>(0, {2, 3, 5, 2 }, { 3 });
100 }
101 
102 TEST_CASE_FIXTURE(AvgPoolLiteFixtureFloat1DOutput, "AvgPoolLiteFloat1DOutput")
103 {
104     RunTest<4, armnn::DataType::Float32>(0, { 2.0f, 3.0f, 5.0f, 2.0f },  { 3.0f });
105 }
106 
107 TEST_CASE_FIXTURE(AvgPoolLiteFixture2DOutput, "AvgPoolLite2DOutput")
108 {
109     RunTest<4, armnn::DataType::QAsymmU8>(
110         0, { 1, 2, 2, 3, 5, 6, 7, 8, 3, 2, 1, 0, 1, 2, 3, 4 }, { 4, 5, 2, 2 });
111 }
112 
113 TEST_CASE_FIXTURE(AvgPoolLiteFixtureFloat1DOutput, "IncorrectDataTypeError")
114 {
115     CHECK_THROWS_AS((RunTest<4, armnn::DataType::QAsymmU8>(0, {2, 3, 5, 2 }, { 3 })), armnn::Exception);
116 }
117 
118 }
119