// xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Conv3D.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include "ParserFlatbuffersFixture.hpp"

#include <algorithm>
#include <sstream>
#include <string>
8 
// Conv3D support was added in TF 2.5, so for backwards compatibility these tests are only compiled when ARMNN_POST_TFLITE_2_4 is defined.
10 #if defined(ARMNN_POST_TFLITE_2_4)
11 TEST_SUITE("TensorflowLiteParser_Conv3D")
12 {
13 struct SimpleConv3DFixture : public ParserFlatbuffersFixture
14 {
SimpleConv3DFixtureSimpleConv3DFixture15     explicit SimpleConv3DFixture()
16     {
17         m_JsonString = R"(
18             {
19                 "version": 3,
20                 "operator_codes": [ { "builtin_code": "CONV_3D" } ],
21                 "subgraphs": [ {
22                     "tensors": [
23                         {
24                             "shape": [ 1, 2, 3, 3, 1 ],
25                             "type": "UINT8",
26                             "buffer": 0,
27                             "name": "inputTensor",
28                             "quantization": {
29                                 "min": [ 0.0 ],
30                                 "max": [ 255.0 ],
31                                 "scale": [ 1.0 ],
32                                 "zero_point": [ 0 ],
33                             }
34                         },
35                         {
36                             "shape": [ 1, 1, 1, 1, 1 ],
37                             "type": "UINT8",
38                             "buffer": 1,
39                             "name": "outputTensor",
40                             "quantization": {
41                                 "min": [ 0.0 ],
42                                 "max": [ 511.0 ],
43                                 "scale": [ 2.0 ],
44                                 "zero_point": [ 0 ],
45                             }
46                         },
47                         {
48                             "shape": [ 2, 3, 3, 1, 1 ],
49                             "type": "UINT8",
50                             "buffer": 2,
51                             "name": "filterTensor",
52                             "quantization": {
53                                 "min": [ 0.0 ],
54                                 "max": [ 255.0 ],
55                                 "scale": [ 1.0 ],
56                                 "zero_point": [ 0 ],
57                             }
58                         }
59                     ],
60                     "inputs": [ 0 ],
61                     "outputs": [ 1 ],
62                     "operators": [
63                         {
64                             "opcode_index": 0,
65                             "inputs": [ 0, 2 ],
66                             "outputs": [ 1 ],
67                             "builtin_options_type": "Conv3DOptions",
68                             "builtin_options": {
69                                 "padding": "VALID",
70                                 "stride_d": 1,
71                                 "stride_w": 1,
72                                 "stride_h": 1,
73                                 "fused_activation_function": "NONE"
74                             },
75                             "custom_options_format": "FLEXBUFFERS"
76                         }
77                     ],
78                 } ],
79                 "buffers" : [
80                     { },
81                     { },
82                     { "data": [ 2,1,0,  6,2,1, 4,1,2,
83                                 1,2,1,  2,0,2, 2,1,1 ], },
84                     { },
85                 ]
86             }
87         )";
88         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
89     }
90 };
91 
92 TEST_CASE_FIXTURE(SimpleConv3DFixture, "ParseSimpleConv3D")
93 {
94     RunTest<5, armnn::DataType::QAsymmU8>(
95         0,
96         {
97             1, 2, 3,
98             4, 5, 6,
99             7, 8, 9,
100 
101             10, 11, 12,
102             13, 14, 15,
103             16, 17, 18,
104         },
105         // Due to the output scaling we need to half the values.
106         {
107             (1*2 + 2*1 + 3*0 +
108              4*6 + 5*2 + 6*1 +
109              7*4 + 8*1 + 9*2 +
110 
111              10*1 + 11*2 + 12*1 +
112              13*2 + 14*0 + 15*2 +
113              16*2 + 17*1 + 18*1) /2
114         });
115 }
116 struct Conv3DWithBiasesFixture : public ParserFlatbuffersFixture
117 {
Conv3DWithBiasesFixtureConv3DWithBiasesFixture118     explicit Conv3DWithBiasesFixture(const std::string& inputShape,
119                                      const std::string& outputShape,
120                                      const std::string& filterShape,
121                                      const std::string& filterData,
122                                      const std::string& biasShape,
123                                      const std::string& biasData,
124                                      const std::string& strides,
125                                      const std::string& activation="NONE",
126                                      const std::string& filterScale="1.0",
127                                      const std::string& filterZeroPoint="0",
128                                      const std::string& outputScale="1.0",
129                                      const std::string& outputZeroPoint="0")
130     {
131         m_JsonString = R"(
132             {
133                 "version": 3,
134                 "operator_codes": [ { "builtin_code": "CONV_3D" } ],
135                 "subgraphs": [ {
136                     "tensors": [
137                         {
138                             "shape": )" + inputShape + R"(,
139                             "type": "UINT8",
140                             "buffer": 0,
141                             "name": "inputTensor",
142                             "quantization": {
143                                 "min": [ 0.0 ],
144                                 "max": [ 255.0 ],
145                                 "scale": [ 1.0 ],
146                                 "zero_point": [ 0 ],
147                             }
148                         },
149                         {
150                             "shape": )" + outputShape + R"(,
151                             "type": "UINT8",
152                             "buffer": 1,
153                             "name": "outputTensor",
154                             "quantization": {
155                                 "min": [ 0.0 ],
156                                 "max": [ 511.0 ],
157                                 "scale": [ )" + outputScale + R"( ],
158                                 "zero_point": [ )" + outputZeroPoint + R"( ],
159                             }
160                         },
161                         {
162                             "shape": )" + filterShape + R"( ,
163                             "type": "UINT8",
164                             "buffer": 2,
165                             "name": "filterTensor",
166                             "quantization": {
167                                 "min": [ 0.0 ],
168                                 "max": [ 255.0 ],
169                                 "scale": [ )" + filterScale + R"( ],
170                                 "zero_point": [ )" + filterZeroPoint + R"( ],
171                             }
172                         },
173                         {
174                             "shape": )" + biasShape + R"( ,
175                             "type": "INT32",
176                             "buffer": 3,
177                             "name": "biasTensor",
178                             "quantization": {
179                                 "min": [ 0.0 ],
180                                 "max": [ 255.0 ],
181                                 "scale": [ 1.0 ],
182                                 "zero_point": [ 0 ],
183                             }
184                         }
185                     ],
186                     "inputs": [ 0 ],
187                     "outputs": [ 1 ],
188                     "operators": [
189                         {
190                             "opcode_index": 0,
191                             "inputs": [ 0, 2, 3 ],
192                             "outputs": [ 1 ],
193                             "builtin_options_type": "Conv3DOptions",
194                             "builtin_options": {
195                                 "padding": "SAME",
196                                 "stride_d": )" + strides + R"(,
197                                 "stride_w": )" + strides + R"(,
198                                 "stride_h": )" + strides + R"(,
199                                 "fused_activation_function": )" + activation + R"(
200                             },
201                             "custom_options_format": "FLEXBUFFERS"
202                         }
203                     ],
204                 } ],
205                 "buffers" : [
206                     { },
207                     { },
208                     { "data": )" + filterData + R"(, },
209                     { "data": )" + biasData + R"(, },
210                 ]
211             }
212         )";
213         SetupSingleInputSingleOutput("inputTensor", "outputTensor");
214     }
215 };
216 
217 struct SimpleConv3DWithBiasesFixture : Conv3DWithBiasesFixture
218 {
SimpleConv3DWithBiasesFixtureSimpleConv3DWithBiasesFixture219     SimpleConv3DWithBiasesFixture()
220     : Conv3DWithBiasesFixture("[ 1, 2, 2, 2, 1 ]",      // inputShape
221                               "[ 1, 2, 2, 2, 1 ]",      // outputShape
222                               "[ 2, 2, 2, 1, 1 ]",      // filterShape
223                               "[ 2,1, 1,0, 0,1, 1,1 ]", // filterData
224                               "[ 1 ]",                  // biasShape
225                               "[ 5, 0, 0, 0 ]",         // biasData
226                               "1")                      // stride d, w and h
227     {}
228 };
229 
230 TEST_CASE_FIXTURE(SimpleConv3DWithBiasesFixture, "ParseConv3DWithBias")
231 {
232     RunTest<5,
233             armnn::DataType::QAsymmU8>(0,
234                                        { 1, 2, 3, 4, 5, 6, 7, 8 },
235                                        { 33, 21, 23, 13, 28, 25, 27, 21 });
236 }
237 
238 TEST_CASE_FIXTURE(SimpleConv3DWithBiasesFixture, "ParseDynamicConv3DWithBias")
239 {
240     RunTest<5,
241             armnn::DataType::QAsymmU8,
242             armnn::DataType::QAsymmU8>(0,
243                                        { { "inputTensor", { 2, 4, 6, 8, 10, 12, 14, 16 } } },
244                                        { { "outputTensor", {  61, 37, 41, 21, 51, 45, 49, 37 } } },
245                                        true);
246 }
247 
248 struct Relu6Conv3DWithBiasesFixture : Conv3DWithBiasesFixture
249 {
Relu6Conv3DWithBiasesFixtureRelu6Conv3DWithBiasesFixture250     Relu6Conv3DWithBiasesFixture()
251     : Conv3DWithBiasesFixture("[ 1, 2, 2, 2, 1 ]",       // inputShape
252                               "[ 1, 2, 2, 2, 1 ]",       // outputShape
253                               "[ 2, 2, 2, 1, 1 ]",       // filterShape
254                               "[ 2,1, 1,0, 0,1, 1,1 ]",  // filterData
255                               "[ 1 ]",                   // biasShape
256                               "[ 0, 0, 0, 0 ]",          // biasData
257                               "1",                       // stride d, w, and h
258                               "RELU6",                   // activation
259                               "1.0",                     // filter scale
260                               "0",                       // filter zero point
261                               "2.0",                     // output scale
262                               "0")                       // output zero point
263     {}
264 };
265 
266 TEST_CASE_FIXTURE(Relu6Conv3DWithBiasesFixture, "ParseConv3DAndRelu6WithBias")
267 {
268     uint8_t relu6Min = 6 / 2; // Divide by output scale
269 
270     RunTest<5, armnn::DataType::QAsymmU8>(
271         0,
272         {
273            1, 2, 3, 4, 5, 6, 7, 8
274         },
275         // RELU6 cuts output values at +6
276         {
277             std::min(relu6Min, static_cast<uint8_t>((1*2 + 2*1 + 3*1 + 4*0 + 5*0 + 6*1 + 7*1 + 8*1)/2)),
278             std::min(relu6Min, static_cast<uint8_t>((2*2 + 0*1 + 0*1 + 0*0 + 0*0 + 0*1 + 8*1 + 0*1)/2)),
279             std::min(relu6Min, static_cast<uint8_t>((3*2 + 0*1 + 0*1 + 0*0 + 0*0 + 8*1 + 0*1 + 0*1)/2)),
280             std::min(relu6Min, static_cast<uint8_t>((4*2 + 0*1 + 0*1 + 0*0 + 8*0 + 0*1 + 0*1 + 0*1)/2)),
281             std::min(relu6Min, static_cast<uint8_t>((5*2 + 0*1 + 0*1 + 8*0 + 0*0 + 0*1 + 0*1 + 0*1)/2)),
282             std::min(relu6Min, static_cast<uint8_t>((6*2 + 0*1 + 8*1 + 0*0 + 0*0 + 0*1 + 0*1 + 0*1)/2)),
283             std::min(relu6Min, static_cast<uint8_t>((7*2 + 8*1 + 0*1 + 0*0 + 0*0 + 0*1 + 0*1 + 0*1)/2)),
284             std::min(relu6Min, static_cast<uint8_t>((8*2 + 0*1 + 0*1 + 0*0 + 0*0 + 0*1 + 0*1 + 0*1)/2))
285         });
286 }
287 
288 }
289 #endif
290