//
// Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "TestUtils.hpp"

#include <armnn_delegate.hpp>
#include <DelegateTestInterpreter.hpp>

#include <flatbuffers/flatbuffers.h>
#include <tensorflow/lite/kernels/register.h>
#include <tensorflow/lite/version.h>

#include <schema_generated.h>

#include <doctest/doctest.h>

namespace
{

std::vector<char> CreateSplitTfLiteModel(tflite::TensorType tensorType,
                                         std::vector<int32_t>& axisTensorShape,
                                         std::vector<int32_t>& inputTensorShape,
                                         const std::vector<std::vector<int32_t>>& outputTensorShapes,
                                         std::vector<int32_t>& axisData,
                                         const int32_t numSplits,
                                         float quantScale = 1.0f,
                                         int quantOffset  = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

    // Buffer 0 is the empty sentinel buffer, buffer 1 backs the runtime-filled
    // input tensor and buffer 2 holds the constant axis data.
    std::vector<flatbuffers::Offset<tflite::Buffer>> buffers;
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder));
    buffers.push_back(CreateBuffer(flatBufferBuilder,
                                   flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(axisData.data()),
                                                                  sizeof(int32_t) * axisData.size())));

    auto quantizationParameters =
            CreateQuantizationParameters(flatBufferBuilder,
                                         0,
                                         0,
                                         flatBufferBuilder.CreateVector<float>({ quantScale }),
                                         flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    // Tensor order is fixed: axis, input, then the two outputs.
    std::array<flatbuffers::Offset<Tensor>, 4> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(axisTensorShape.data(),
                                                                      axisTensorShape.size()),
                              ::tflite::TensorType_INT32,
                              2,
                              flatBufferBuilder.CreateString("axis"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              tensorType,
                              1,
                              flatBufferBuilder.CreateString("input"),
                              quantizationParameters);
    // Create the output tensors, each backed by its own empty buffer
    for (unsigned int i = 0; i < outputTensorShapes.size(); ++i)
    {
        buffers.push_back(CreateBuffer(flatBufferBuilder));
        tensors[i + 2] = CreateTensor(flatBufferBuilder,
                                      flatBufferBuilder.CreateVector<int32_t>(outputTensorShapes[i].data(),
                                                                              outputTensorShapes[i].size()),
                                      tensorType,
                                      (i + 3),
                                      flatBufferBuilder.CreateString("output"),
                                      quantizationParameters);
    }

    // Create the operator. SPLIT uses SplitOptions.
    tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_SplitOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions = CreateSplitOptions(flatBufferBuilder, numSplits).Union();

    const std::vector<int> operatorInputs{ {0, 1} };
    const std::vector<int> operatorOutputs{ {2, 3} };
    flatbuffers::Offset <Operator> controlOperator =
            CreateOperator(flatBufferBuilder,
                           0,
                           flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                           flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                           operatorBuiltinOptionsType,
                           operatorBuiltinOptions);

    const std::vector<int> subgraphInputs{ {0, 1} };
    const std::vector<int> subgraphOutputs{ {2, 3} };
    flatbuffers::Offset <SubGraph> subgraph =
            CreateSubGraph(flatBufferBuilder,
                           flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                           flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                           flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                           flatBufferBuilder.CreateVector(&controlOperator, 1));

    flatbuffers::Offset <flatbuffers::String> modelDescription =
            flatBufferBuilder.CreateString("ArmnnDelegate: SPLIT Operator Model");
    flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, BuiltinOperator_SPLIT);

    flatbuffers::Offset <Model> flatbufferModel =
            CreateModel(flatBufferBuilder,
                        TFLITE_SCHEMA_VERSION,
                        flatBufferBuilder.CreateVector(&operatorCode, 1),
                        flatBufferBuilder.CreateVector(&subgraph, 1),
                        modelDescription,
                        flatBufferBuilder.CreateVector(buffers));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}
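
// A small verification sketch, not part of the original helper: the bytes
// returned above form a complete TFLite flatbuffer, so the generated schema
// API can re-parse them directly, which is handy when debugging a malformed
// test model. The helper name below is ours, not an armnn or TFLite API.
[[maybe_unused]] bool SplitModelLooksValid(const std::vector<char>& modelBuffer)
{
    const tflite::Model* model = tflite::GetModel(modelBuffer.data());
    return model != nullptr
           && model->subgraphs()->size() == 1                        // single subgraph
           && model->subgraphs()->Get(0)->operators()->size() == 1;  // single SPLIT op
}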

template <typename T>
void SplitTest(tflite::TensorType tensorType,
               std::vector<armnn::BackendId>& backends,
               std::vector<int32_t>& axisTensorShape,
               std::vector<int32_t>& inputTensorShape,
               std::vector<std::vector<int32_t>>& outputTensorShapes,
               std::vector<int32_t>& axisData,
               std::vector<T>& inputValues,
               std::vector<std::vector<T>>& expectedOutputValues,
               const int32_t numSplits,
               float quantScale = 1.0f,
               int quantOffset  = 0)
{
    using namespace delegateTestInterpreter;
    std::vector<char> modelBuffer = CreateSplitTfLiteModel(tensorType,
                                                           axisTensorShape,
                                                           inputTensorShape,
                                                           outputTensorShapes,
                                                           axisData,
                                                           numSplits,
                                                           quantScale,
                                                           quantOffset);
    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    // The data goes into subgraph input 1; input 0 is the constant axis tensor.
    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 1) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 1) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);

    // Compare output data
    for (unsigned int i = 0; i < expectedOutputValues.size(); ++i)
    {
        std::vector<T>       tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(i);
        std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(i);

        std::vector<T>       armnnOutputValues = armnnInterpreter.GetOutputResult<T>(i);
        std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(i);

        armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues[i]);
        armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShapes[i]);
    }

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
} // End of SPLIT Test
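
// An illustrative usage sketch (not one of the shipped test cases): split a
// [2, 2, 2, 2] float tensor into two halves along axis 1 and check that the
// TFLite Runtime and the Arm NN delegate agree. The backend choice, shapes and
// values below are assumptions chosen for the example.
TEST_CASE ("SPLIT_Axis1_Fp32_Example")
{
    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
    std::vector<int32_t> axisShape { 1 };
    std::vector<int32_t> inputShape { 2, 2, 2, 2 };
    std::vector<std::vector<int32_t>> outputShapes { { 2, 1, 2, 2 },
                                                     { 2, 1, 2, 2 } };
    std::vector<int32_t> axisData { 1 };
    std::vector<float> inputValues { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f,
                                     9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f };
    // Splitting on axis 1 sends slice 0 of each batch to the first output and
    // slice 1 to the second.
    std::vector<std::vector<float>> expectedOutputValues {
        { 1.f, 2.f, 3.f, 4.f, 9.f, 10.f, 11.f, 12.f },
        { 5.f, 6.f, 7.f, 8.f, 13.f, 14.f, 15.f, 16.f } };

    SplitTest<float>(tflite::TensorType_FLOAT32,
                     backends,
                     axisShape,
                     inputShape,
                     outputShapes,
                     axisData,
                     inputValues,
                     expectedOutputValues,
                     2);    // numSplits
}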

std::vector<char> CreateSplitVTfLiteModel(tflite::TensorType tensorType,
                                          std::vector<int32_t>& inputTensorShape,
                                          std::vector<int32_t>& splitsTensorShape,
                                          std::vector<int32_t>& axisTensorShape,
                                          const std::vector<std::vector<int32_t>>& outputTensorShapes,
                                          std::vector<int32_t>& splitsData,
                                          std::vector<int32_t>& axisData,
                                          const int32_t numSplits,
                                          float quantScale = 1.0f,
                                          int quantOffset  = 0)
{
    using namespace tflite;
    flatbuffers::FlatBufferBuilder flatBufferBuilder;

    // Buffer 0 is the empty sentinel buffer; buffers 1 and 2 hold the constant
    // splits and axis data.
    std::array<flatbuffers::Offset<tflite::Buffer>, 3> buffers;
    buffers[0] = CreateBuffer(flatBufferBuilder, flatBufferBuilder.CreateVector({}));
    buffers[1] = CreateBuffer(flatBufferBuilder,
                              flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(splitsData.data()),
                                                             sizeof(int32_t) * splitsData.size()));
    buffers[2] = CreateBuffer(flatBufferBuilder,
                              flatBufferBuilder.CreateVector(reinterpret_cast<const uint8_t*>(axisData.data()),
                                                             sizeof(int32_t) * axisData.size()));

    auto quantizationParameters =
            CreateQuantizationParameters(flatBufferBuilder,
                                         0,
                                         0,
                                         flatBufferBuilder.CreateVector<float>({ quantScale }),
                                         flatBufferBuilder.CreateVector<int64_t>({ quantOffset }));

    // Tensor order is fixed: input, splits, axis, then the two outputs.
    std::array<flatbuffers::Offset<Tensor>, 5> tensors;
    tensors[0] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(inputTensorShape.data(),
                                                                      inputTensorShape.size()),
                              tensorType,
                              0,
                              flatBufferBuilder.CreateString("input"),
                              quantizationParameters);
    tensors[1] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(splitsTensorShape.data(),
                                                                      splitsTensorShape.size()),
                              ::tflite::TensorType_INT32,
                              1,
                              flatBufferBuilder.CreateString("splits"),
                              quantizationParameters);
    tensors[2] = CreateTensor(flatBufferBuilder,
                              flatBufferBuilder.CreateVector<int32_t>(axisTensorShape.data(),
                                                                      axisTensorShape.size()),
                              ::tflite::TensorType_INT32,
                              2,
                              flatBufferBuilder.CreateString("axis"),
                              quantizationParameters);

    // Create the output tensors; they share the empty buffer 0
    for (unsigned int i = 0; i < outputTensorShapes.size(); ++i)
    {
        tensors[i + 3] = CreateTensor(flatBufferBuilder,
                                      flatBufferBuilder.CreateVector<int32_t>(outputTensorShapes[i].data(),
                                                                              outputTensorShapes[i].size()),
                                      tensorType,
                                      0,
                                      flatBufferBuilder.CreateString("output"),
                                      quantizationParameters);
    }

    // Create the operator. SPLIT_V uses SplitVOptions.
    tflite::BuiltinOptions operatorBuiltinOptionsType = tflite::BuiltinOptions_SplitVOptions;
    flatbuffers::Offset<void> operatorBuiltinOptions = CreateSplitVOptions(flatBufferBuilder, numSplits).Union();

    const std::vector<int> operatorInputs{ {0, 1, 2} };
    const std::vector<int> operatorOutputs{ {3, 4} };
    flatbuffers::Offset <Operator> controlOperator =
            CreateOperator(flatBufferBuilder,
                           0,
                           flatBufferBuilder.CreateVector<int32_t>(operatorInputs.data(), operatorInputs.size()),
                           flatBufferBuilder.CreateVector<int32_t>(operatorOutputs.data(), operatorOutputs.size()),
                           operatorBuiltinOptionsType,
                           operatorBuiltinOptions);

    const std::vector<int> subgraphInputs{ {0, 1, 2} };
    const std::vector<int> subgraphOutputs{ {3, 4} };
    flatbuffers::Offset <SubGraph> subgraph =
            CreateSubGraph(flatBufferBuilder,
                           flatBufferBuilder.CreateVector(tensors.data(), tensors.size()),
                           flatBufferBuilder.CreateVector<int32_t>(subgraphInputs.data(), subgraphInputs.size()),
                           flatBufferBuilder.CreateVector<int32_t>(subgraphOutputs.data(), subgraphOutputs.size()),
                           flatBufferBuilder.CreateVector(&controlOperator, 1));

    flatbuffers::Offset <flatbuffers::String> modelDescription =
            flatBufferBuilder.CreateString("ArmnnDelegate: SPLIT_V Operator Model");
    flatbuffers::Offset <OperatorCode> operatorCode = CreateOperatorCode(flatBufferBuilder, BuiltinOperator_SPLIT_V);

    flatbuffers::Offset <Model> flatbufferModel =
            CreateModel(flatBufferBuilder,
                        TFLITE_SCHEMA_VERSION,
                        flatBufferBuilder.CreateVector(&operatorCode, 1),
                        flatBufferBuilder.CreateVector(&subgraph, 1),
                        modelDescription,
                        flatBufferBuilder.CreateVector(buffers.data(), buffers.size()));

    flatBufferBuilder.Finish(flatbufferModel, armnnDelegate::FILE_IDENTIFIER);

    return std::vector<char>(flatBufferBuilder.GetBufferPointer(),
                             flatBufferBuilder.GetBufferPointer() + flatBufferBuilder.GetSize());
}

template <typename T>
void SplitVTest(tflite::TensorType tensorType,
                std::vector<armnn::BackendId>& backends,
                std::vector<int32_t>& inputTensorShape,
                std::vector<int32_t>& splitsTensorShape,
                std::vector<int32_t>& axisTensorShape,
                std::vector<std::vector<int32_t>>& outputTensorShapes,
                std::vector<T>& inputValues,
                std::vector<int32_t>& splitsData,
                std::vector<int32_t>& axisData,
                std::vector<std::vector<T>>& expectedOutputValues,
                const int32_t numSplits,
                float quantScale = 1.0f,
                int quantOffset  = 0)
{
    using namespace delegateTestInterpreter;
    std::vector<char> modelBuffer = CreateSplitVTfLiteModel(tensorType,
                                                            inputTensorShape,
                                                            splitsTensorShape,
                                                            axisTensorShape,
                                                            outputTensorShapes,
                                                            splitsData,
                                                            axisData,
                                                            numSplits,
                                                            quantScale,
                                                            quantOffset);

    // Setup interpreter with just TFLite Runtime.
    auto tfLiteInterpreter = DelegateTestInterpreter(modelBuffer);
    CHECK(tfLiteInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(tfLiteInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(tfLiteInterpreter.Invoke() == kTfLiteOk);

    // Setup interpreter with Arm NN Delegate applied.
    auto armnnInterpreter = DelegateTestInterpreter(modelBuffer, backends);
    CHECK(armnnInterpreter.AllocateTensors() == kTfLiteOk);
    CHECK(armnnInterpreter.FillInputTensor<T>(inputValues, 0) == kTfLiteOk);
    CHECK(armnnInterpreter.Invoke() == kTfLiteOk);

    // Compare output data
    for (unsigned int i = 0; i < expectedOutputValues.size(); ++i)
    {
        std::vector<T>       tfLiteOutputValues = tfLiteInterpreter.GetOutputResult<T>(i);
        std::vector<int32_t> tfLiteOutputShape  = tfLiteInterpreter.GetOutputShape(i);

        std::vector<T>       armnnOutputValues = armnnInterpreter.GetOutputResult<T>(i);
        std::vector<int32_t> armnnOutputShape  = armnnInterpreter.GetOutputShape(i);

        armnnDelegate::CompareOutputData<T>(tfLiteOutputValues, armnnOutputValues, expectedOutputValues[i]);
        armnnDelegate::CompareOutputShape(tfLiteOutputShape, armnnOutputShape, outputTensorShapes[i]);
    }

    tfLiteInterpreter.Cleanup();
    armnnInterpreter.Cleanup();
} // End of SPLIT_V Test
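
// An illustrative usage sketch for the SPLIT_V path (not one of the shipped
// test cases): uneven split sizes { 3, 1 } along axis 1 of a [2, 4, 2, 2]
// float tensor. The backend choice, shapes and values below are assumptions
// chosen for the example.
TEST_CASE ("SPLIT_V_Axis1_Fp32_Example")
{
    std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
    std::vector<int32_t> inputShape { 2, 4, 2, 2 };
    std::vector<int32_t> splitsShape { 2 };
    std::vector<int32_t> axisShape { 1 };
    std::vector<std::vector<int32_t>> outputShapes { { 2, 3, 2, 2 },
                                                     { 2, 1, 2, 2 } };
    std::vector<int32_t> splitsData { 3, 1 };
    std::vector<int32_t> axisData { 1 };

    // Fill the input with 1..32 in row-major order.
    std::vector<float> inputValues(32);
    for (int n = 0; n < 32; ++n)
    {
        inputValues[n] = static_cast<float>(n + 1);
    }

    // The first output keeps slices 0..2 of axis 1, the second keeps slice 3.
    std::vector<std::vector<float>> expectedOutputValues {
        { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,
          17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f },
        { 13.f, 14.f, 15.f, 16.f, 29.f, 30.f, 31.f, 32.f } };

    SplitVTest<float>(tflite::TensorType_FLOAT32,
                      backends,
                      inputShape,
                      splitsShape,
                      axisShape,
                      outputShapes,
                      inputValues,
                      splitsData,
                      axisData,
                      expectedOutputValues,
                      2);    // numSplits
}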

} // anonymous namespace