xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/MaxPool2D.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
#include "ParserFlatbuffersFixture.hpp"

#include <string>
7 
8 TEST_SUITE("TensorflowLiteParser_MaxPool2D")
9 {
10 struct MaxPool2DFixture : public ParserFlatbuffersFixture
11 {
MaxPool2DFixtureMaxPool2DFixture12     explicit MaxPool2DFixture(std::string inputdim, std::string outputdim, std::string dataType)
13     {
14         m_JsonString = R"(
15         {
16             "version": 3,
17             "operator_codes": [ { "builtin_code": "MAX_POOL_2D" } ],
18             "subgraphs": [
19             {
20                 "tensors": [
21                 {
22                     "shape": )"
23                     + outputdim
24                     + R"(,
25                     "type": )"
26                       + dataType
27                       + R"(,
28                             "buffer": 0,
29                             "name": "OutputTensor",
30                             "quantization": {
31                                 "min": [ 0.0 ],
32                                 "max": [ 255.0 ],
33                                 "scale": [ 1.0 ],
34                                 "zero_point": [ 0 ]
35                             }
36                 },
37                 {
38                     "shape": )"
39                     + inputdim
40                     + R"(,
41                     "type": )"
42                       + dataType
43                       + R"(,
44                             "buffer": 1,
45                             "name": "InputTensor",
46                             "quantization": {
47                                 "min": [ 0.0 ],
48                                 "max": [ 255.0 ],
49                                 "scale": [ 1.0 ],
50                                 "zero_point": [ 0 ]
51                             }
52                 }
53                 ],
54                 "inputs": [ 1 ],
55                 "outputs": [ 0 ],
56                 "operators": [ {
57                         "opcode_index": 0,
58                         "inputs": [ 1 ],
59                         "outputs": [ 0 ],
60                         "builtin_options_type": "Pool2DOptions",
61                         "builtin_options":
62                         {
63                             "padding": "VALID",
64                             "stride_w": 2,
65                             "stride_h": 2,
66                             "filter_width": 2,
67                             "filter_height": 2,
68                             "fused_activation_function": "NONE"
69                         },
70                         "custom_options_format": "FLEXBUFFERS"
71                     } ]
72                 }
73             ],
74             "description": "MaxPool2D test.",
75             "buffers" : [ {}, {} ]
76         })";
77 
78         SetupSingleInputSingleOutput("InputTensor", "OutputTensor");
79     }
80 };
81 
82 
83 struct MaxPoolLiteFixtureUint1DOutput : MaxPool2DFixture
84 {
MaxPoolLiteFixtureUint1DOutputMaxPoolLiteFixtureUint1DOutput85     MaxPoolLiteFixtureUint1DOutput() : MaxPool2DFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]", "UINT8") {}
86 };
87 
88 struct MaxPoolLiteFixtureFloat1DOutput : MaxPool2DFixture
89 {
MaxPoolLiteFixtureFloat1DOutputMaxPoolLiteFixtureFloat1DOutput90     MaxPoolLiteFixtureFloat1DOutput() : MaxPool2DFixture("[ 1, 2, 2, 1 ]", "[ 1, 1, 1, 1 ]", "FLOAT32") {}
91 };
92 
93 struct MaxPoolLiteFixtureUint2DOutput : MaxPool2DFixture
94 {
MaxPoolLiteFixtureUint2DOutputMaxPoolLiteFixtureUint2DOutput95     MaxPoolLiteFixtureUint2DOutput() : MaxPool2DFixture("[ 1, 4, 4, 1 ]", "[ 1, 2, 2, 1 ]", "UINT8") {}
96 };
97 
98 TEST_CASE_FIXTURE(MaxPoolLiteFixtureUint1DOutput, "MaxPoolLiteUint1DOutput")
99 {
100     RunTest<4, armnn::DataType::QAsymmU8>(0, { 2, 3, 5, 2 }, { 5 });
101 }
102 
103 TEST_CASE_FIXTURE(MaxPoolLiteFixtureFloat1DOutput, "MaxPoolLiteFloat1DOutput")
104 {
105     RunTest<4, armnn::DataType::Float32>(0, { 2.0f, 3.0f, 5.0f, 2.0f },  { 5.0f });
106 }
107 
108 TEST_CASE_FIXTURE(MaxPoolLiteFixtureUint2DOutput, "MaxPoolLiteUint2DOutput")
109 {
110     RunTest<4, armnn::DataType::QAsymmU8>(
111         0, { 1, 2, 2, 3, 5, 6, 7, 8, 3, 2, 1, 0, 1, 2, 3, 4 }, { 6, 8, 3, 4 });
112 }
113 
114 TEST_CASE_FIXTURE(MaxPoolLiteFixtureFloat1DOutput, "MaxPoolIncorrectDataTypeError")
115 {
116     CHECK_THROWS_AS((RunTest<4, armnn::DataType::QAsymmU8>(0, { 2, 3, 5, 2 }, { 5 })), armnn::Exception);
117 }
118 
119 }
120