xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/ParserFlatbuffersFixture.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "Schema.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker #include "../TfLiteParser.hpp"
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/TensorHelpers.hpp>
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
22*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker #include "flatbuffers/idl.h"
25*89c4ff92SAndroid Build Coastguard Worker #include "flatbuffers/util.h"
26*89c4ff92SAndroid Build Coastguard Worker #include "flatbuffers/flexbuffers.h"
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::ITfLiteParser;
32*89c4ff92SAndroid Build Coastguard Worker using armnnTfLiteParser::ITfLiteParserPtr;
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtr = const tflite::TensorT *;
35*89c4ff92SAndroid Build Coastguard Worker struct ParserFlatbuffersFixture
36*89c4ff92SAndroid Build Coastguard Worker {
ParserFlatbuffersFixtureParserFlatbuffersFixture37*89c4ff92SAndroid Build Coastguard Worker     ParserFlatbuffersFixture() :
38*89c4ff92SAndroid Build Coastguard Worker             m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions())),
39*89c4ff92SAndroid Build Coastguard Worker             m_NetworkIdentifier(0),
40*89c4ff92SAndroid Build Coastguard Worker             m_DynamicNetworkIdentifier(1)
41*89c4ff92SAndroid Build Coastguard Worker     {
42*89c4ff92SAndroid Build Coastguard Worker         ITfLiteParser::TfLiteParserOptions options;
43*89c4ff92SAndroid Build Coastguard Worker         options.m_StandInLayerForUnsupported = true;
44*89c4ff92SAndroid Build Coastguard Worker         options.m_InferAndValidate = true;
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker         m_Parser = std::make_unique<armnnTfLiteParser::TfLiteParserImpl>(
47*89c4ff92SAndroid Build Coastguard Worker                         armnn::Optional<ITfLiteParser::TfLiteParserOptions>(options));
48*89c4ff92SAndroid Build Coastguard Worker     }
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> m_GraphBinary;
51*89c4ff92SAndroid Build Coastguard Worker     std::string          m_JsonString;
52*89c4ff92SAndroid Build Coastguard Worker     armnn::IRuntimePtr   m_Runtime;
53*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId     m_NetworkIdentifier;
54*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId     m_DynamicNetworkIdentifier;
55*89c4ff92SAndroid Build Coastguard Worker     bool                 m_TestDynamic;
56*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<armnnTfLiteParser::TfLiteParserImpl> m_Parser;
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
59*89c4ff92SAndroid Build Coastguard Worker     /// so they don't need to be passed to the single-input-single-output overload of RunTest().
60*89c4ff92SAndroid Build Coastguard Worker     std::string m_SingleInputName;
61*89c4ff92SAndroid Build Coastguard Worker     std::string m_SingleOutputName;
62*89c4ff92SAndroid Build Coastguard Worker 
SetupParserFlatbuffersFixture63*89c4ff92SAndroid Build Coastguard Worker     void Setup(bool testDynamic = true)
64*89c4ff92SAndroid Build Coastguard Worker     {
65*89c4ff92SAndroid Build Coastguard Worker         m_TestDynamic = testDynamic;
66*89c4ff92SAndroid Build Coastguard Worker         loadNetwork(m_NetworkIdentifier, false);
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker         if (m_TestDynamic)
69*89c4ff92SAndroid Build Coastguard Worker         {
70*89c4ff92SAndroid Build Coastguard Worker             loadNetwork(m_DynamicNetworkIdentifier, true);
71*89c4ff92SAndroid Build Coastguard Worker         }
72*89c4ff92SAndroid Build Coastguard Worker     }
73*89c4ff92SAndroid Build Coastguard Worker 
MakeModelDynamicParserFlatbuffersFixture74*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<tflite::ModelT> MakeModelDynamic(std::vector<uint8_t> graphBinary)
75*89c4ff92SAndroid Build Coastguard Worker     {
76*89c4ff92SAndroid Build Coastguard Worker         const uint8_t* binaryContent = graphBinary.data();
77*89c4ff92SAndroid Build Coastguard Worker         const size_t len = graphBinary.size();
78*89c4ff92SAndroid Build Coastguard Worker         if (binaryContent == nullptr)
79*89c4ff92SAndroid Build Coastguard Worker         {
80*89c4ff92SAndroid Build Coastguard Worker             throw armnn::InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
81*89c4ff92SAndroid Build Coastguard Worker                                                                CHECK_LOCATION().AsString()));
82*89c4ff92SAndroid Build Coastguard Worker         }
83*89c4ff92SAndroid Build Coastguard Worker         flatbuffers::Verifier verifier(binaryContent, len);
84*89c4ff92SAndroid Build Coastguard Worker         if (verifier.VerifyBuffer<tflite::Model>() == false)
85*89c4ff92SAndroid Build Coastguard Worker         {
86*89c4ff92SAndroid Build Coastguard Worker             throw armnn::ParseException(fmt::format("Buffer doesn't conform to the expected Tensorflow Lite "
87*89c4ff92SAndroid Build Coastguard Worker                                                     "flatbuffers format. size:{} {}",
88*89c4ff92SAndroid Build Coastguard Worker                                                     len,
89*89c4ff92SAndroid Build Coastguard Worker                                                     CHECK_LOCATION().AsString()));
90*89c4ff92SAndroid Build Coastguard Worker         }
91*89c4ff92SAndroid Build Coastguard Worker         auto model =  tflite::UnPackModel(binaryContent);
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker         for (auto const& subgraph : model->subgraphs)
94*89c4ff92SAndroid Build Coastguard Worker         {
95*89c4ff92SAndroid Build Coastguard Worker             std::vector<int32_t> inputIds = subgraph->inputs;
96*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int tensorIndex = 0; tensorIndex < subgraph->tensors.size(); ++tensorIndex)
97*89c4ff92SAndroid Build Coastguard Worker             {
98*89c4ff92SAndroid Build Coastguard Worker                 if (std::find(inputIds.begin(), inputIds.end(), tensorIndex) != inputIds.end())
99*89c4ff92SAndroid Build Coastguard Worker                 {
100*89c4ff92SAndroid Build Coastguard Worker                     continue;
101*89c4ff92SAndroid Build Coastguard Worker                 }
102*89c4ff92SAndroid Build Coastguard Worker                 for (auto const& tensor : subgraph->tensors)
103*89c4ff92SAndroid Build Coastguard Worker                 {
104*89c4ff92SAndroid Build Coastguard Worker                     if (tensor->shape_signature.size() != 0)
105*89c4ff92SAndroid Build Coastguard Worker                     {
106*89c4ff92SAndroid Build Coastguard Worker                         continue;
107*89c4ff92SAndroid Build Coastguard Worker                     }
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker                     for (unsigned int i = 0; i < tensor->shape.size(); ++i)
110*89c4ff92SAndroid Build Coastguard Worker                     {
111*89c4ff92SAndroid Build Coastguard Worker                         tensor->shape_signature.push_back(-1);
112*89c4ff92SAndroid Build Coastguard Worker                     }
113*89c4ff92SAndroid Build Coastguard Worker                 }
114*89c4ff92SAndroid Build Coastguard Worker             }
115*89c4ff92SAndroid Build Coastguard Worker         }
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker         return model;
118*89c4ff92SAndroid Build Coastguard Worker     }
119*89c4ff92SAndroid Build Coastguard Worker 
loadNetworkParserFlatbuffersFixture120*89c4ff92SAndroid Build Coastguard Worker     void loadNetwork(armnn::NetworkId networkId, bool loadDynamic)
121*89c4ff92SAndroid Build Coastguard Worker     {
122*89c4ff92SAndroid Build Coastguard Worker         if (!ReadStringToBinary())
123*89c4ff92SAndroid Build Coastguard Worker         {
124*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("LoadNetwork failed while reading binary input");
125*89c4ff92SAndroid Build Coastguard Worker         }
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker         armnn::INetworkPtr network = loadDynamic ? m_Parser->LoadModel(MakeModelDynamic(m_GraphBinary))
128*89c4ff92SAndroid Build Coastguard Worker                                                  : m_Parser->CreateNetworkFromBinary(m_GraphBinary);
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker         if (!network) {
131*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception("The parser failed to create an ArmNN network");
132*89c4ff92SAndroid Build Coastguard Worker         }
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker         auto optimized = Optimize(*network, { armnn::Compute::CpuRef },
135*89c4ff92SAndroid Build Coastguard Worker                                   m_Runtime->GetDeviceSpec());
136*89c4ff92SAndroid Build Coastguard Worker         std::string errorMessage;
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker         armnn::Status ret = m_Runtime->LoadNetwork(networkId, move(optimized), errorMessage);
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker         if (ret != armnn::Status::Success)
141*89c4ff92SAndroid Build Coastguard Worker         {
142*89c4ff92SAndroid Build Coastguard Worker             throw armnn::Exception(
143*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("The runtime failed to load the network. "
144*89c4ff92SAndroid Build Coastguard Worker                             "Error was: {}. in {} [{}:{}]",
145*89c4ff92SAndroid Build Coastguard Worker                             errorMessage,
146*89c4ff92SAndroid Build Coastguard Worker                             __func__,
147*89c4ff92SAndroid Build Coastguard Worker                             __FILE__,
148*89c4ff92SAndroid Build Coastguard Worker                             __LINE__));
149*89c4ff92SAndroid Build Coastguard Worker         }
150*89c4ff92SAndroid Build Coastguard Worker     }
151*89c4ff92SAndroid Build Coastguard Worker 
SetupSingleInputSingleOutputParserFlatbuffersFixture152*89c4ff92SAndroid Build Coastguard Worker     void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName)
153*89c4ff92SAndroid Build Coastguard Worker     {
154*89c4ff92SAndroid Build Coastguard Worker         // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
155*89c4ff92SAndroid Build Coastguard Worker         m_SingleInputName = inputName;
156*89c4ff92SAndroid Build Coastguard Worker         m_SingleOutputName = outputName;
157*89c4ff92SAndroid Build Coastguard Worker         Setup();
158*89c4ff92SAndroid Build Coastguard Worker     }
159*89c4ff92SAndroid Build Coastguard Worker 
ReadStringToBinaryParserFlatbuffersFixture160*89c4ff92SAndroid Build Coastguard Worker     bool ReadStringToBinary()
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         std::string schemafile(g_TfLiteSchemaText, g_TfLiteSchemaText + g_TfLiteSchemaText_len);
163*89c4ff92SAndroid Build Coastguard Worker 
164*89c4ff92SAndroid Build Coastguard Worker         // parse schema first, so we can use it to parse the data after
165*89c4ff92SAndroid Build Coastguard Worker         flatbuffers::Parser parser;
166*89c4ff92SAndroid Build Coastguard Worker 
167*89c4ff92SAndroid Build Coastguard Worker         bool ok = parser.Parse(schemafile.c_str());
168*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(ok, std::string("Failed to parse schema file. Error was: " + parser.error_).c_str());
169*89c4ff92SAndroid Build Coastguard Worker 
170*89c4ff92SAndroid Build Coastguard Worker         ok = parser.Parse(m_JsonString.c_str());
171*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(ok, std::string("Failed to parse json input. Error was: " + parser.error_).c_str());
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker         {
174*89c4ff92SAndroid Build Coastguard Worker             const uint8_t * bufferPtr = parser.builder_.GetBufferPointer();
175*89c4ff92SAndroid Build Coastguard Worker             size_t size = static_cast<size_t>(parser.builder_.GetSize());
176*89c4ff92SAndroid Build Coastguard Worker             m_GraphBinary.assign(bufferPtr, bufferPtr+size);
177*89c4ff92SAndroid Build Coastguard Worker         }
178*89c4ff92SAndroid Build Coastguard Worker         return ok;
179*89c4ff92SAndroid Build Coastguard Worker     }
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker     /// Executes the network with the given input tensor and checks the result against the given output tensor.
182*89c4ff92SAndroid Build Coastguard Worker     /// This assumes the network has a single input and a single output.
183*89c4ff92SAndroid Build Coastguard Worker     template <std::size_t NumOutputDimensions,
184*89c4ff92SAndroid Build Coastguard Worker               armnn::DataType ArmnnType>
185*89c4ff92SAndroid Build Coastguard Worker     void RunTest(size_t subgraphId,
186*89c4ff92SAndroid Build Coastguard Worker                  const std::vector<armnn::ResolveType<ArmnnType>>& inputData,
187*89c4ff92SAndroid Build Coastguard Worker                  const std::vector<armnn::ResolveType<ArmnnType>>& expectedOutputData);
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker     /// Executes the network with the given input tensors and checks the results against the given output tensors.
190*89c4ff92SAndroid Build Coastguard Worker     /// This overload supports multiple inputs and multiple outputs, identified by name.
191*89c4ff92SAndroid Build Coastguard Worker     template <std::size_t NumOutputDimensions,
192*89c4ff92SAndroid Build Coastguard Worker               armnn::DataType ArmnnType>
193*89c4ff92SAndroid Build Coastguard Worker     void RunTest(size_t subgraphId,
194*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType>>>& inputData,
195*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType>>>& expectedOutputData);
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker     /// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.
198*89c4ff92SAndroid Build Coastguard Worker     /// Executes the network with the given input tensors and checks the results against the given output tensors.
199*89c4ff92SAndroid Build Coastguard Worker     /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
200*89c4ff92SAndroid Build Coastguard Worker     /// the input datatype to be different to the output
201*89c4ff92SAndroid Build Coastguard Worker     template <std::size_t NumOutputDimensions,
202*89c4ff92SAndroid Build Coastguard Worker               armnn::DataType ArmnnType1,
203*89c4ff92SAndroid Build Coastguard Worker               armnn::DataType ArmnnType2>
204*89c4ff92SAndroid Build Coastguard Worker     void RunTest(size_t subgraphId,
205*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType1>>>& inputData,
206*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType2>>>& expectedOutputData,
207*89c4ff92SAndroid Build Coastguard Worker                  bool isDynamic = false);
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker     /// Multiple Inputs with different DataTypes, Multiple Outputs w/ Variable DataTypes
210*89c4ff92SAndroid Build Coastguard Worker     /// Executes the network with the given input tensors and checks the results against the given output tensors.
211*89c4ff92SAndroid Build Coastguard Worker     /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
212*89c4ff92SAndroid Build Coastguard Worker     /// the input datatype to be different to the output
213*89c4ff92SAndroid Build Coastguard Worker     template <std::size_t NumOutputDimensions,
214*89c4ff92SAndroid Build Coastguard Worker         armnn::DataType inputType1,
215*89c4ff92SAndroid Build Coastguard Worker         armnn::DataType inputType2,
216*89c4ff92SAndroid Build Coastguard Worker         armnn::DataType outputType>
217*89c4ff92SAndroid Build Coastguard Worker     void RunTest(size_t subgraphId,
218*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<inputType1>>>& input1Data,
219*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<inputType2>>>& input2Data,
220*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<outputType>>>& expectedOutputData);
221*89c4ff92SAndroid Build Coastguard Worker 
222*89c4ff92SAndroid Build Coastguard Worker     /// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.
223*89c4ff92SAndroid Build Coastguard Worker     /// Executes the network with the given input tensors and checks the results against the given output tensors.
224*89c4ff92SAndroid Build Coastguard Worker     /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
225*89c4ff92SAndroid Build Coastguard Worker     /// the input datatype to be different to the output
226*89c4ff92SAndroid Build Coastguard Worker     template<armnn::DataType ArmnnType1,
227*89c4ff92SAndroid Build Coastguard Worker              armnn::DataType ArmnnType2>
228*89c4ff92SAndroid Build Coastguard Worker     void RunTest(std::size_t subgraphId,
229*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType1>>>& inputData,
230*89c4ff92SAndroid Build Coastguard Worker                  const std::map<std::string, std::vector<armnn::ResolveType<ArmnnType2>>>& expectedOutputData);
231*89c4ff92SAndroid Build Coastguard Worker 
GenerateDetectionPostProcessJsonStringParserFlatbuffersFixture232*89c4ff92SAndroid Build Coastguard Worker     static inline std::string GenerateDetectionPostProcessJsonString(
233*89c4ff92SAndroid Build Coastguard Worker         const armnn::DetectionPostProcessDescriptor& descriptor)
234*89c4ff92SAndroid Build Coastguard Worker     {
235*89c4ff92SAndroid Build Coastguard Worker         flexbuffers::Builder detectPostProcess;
236*89c4ff92SAndroid Build Coastguard Worker         detectPostProcess.Map([&]() {
237*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Bool("use_regular_nms", descriptor.m_UseRegularNms);
238*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Int("max_detections", descriptor.m_MaxDetections);
239*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Int("max_classes_per_detection", descriptor.m_MaxClassesPerDetection);
240*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Int("detections_per_class", descriptor.m_DetectionsPerClass);
241*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Int("num_classes", descriptor.m_NumClasses);
242*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Float("nms_score_threshold", descriptor.m_NmsScoreThreshold);
243*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Float("nms_iou_threshold", descriptor.m_NmsIouThreshold);
244*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Float("h_scale", descriptor.m_ScaleH);
245*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Float("w_scale", descriptor.m_ScaleW);
246*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Float("x_scale", descriptor.m_ScaleX);
247*89c4ff92SAndroid Build Coastguard Worker             detectPostProcess.Float("y_scale", descriptor.m_ScaleY);
248*89c4ff92SAndroid Build Coastguard Worker         });
249*89c4ff92SAndroid Build Coastguard Worker         detectPostProcess.Finish();
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker         // Create JSON string
252*89c4ff92SAndroid Build Coastguard Worker         std::stringstream strStream;
253*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint8_t> buffer = detectPostProcess.GetBuffer();
254*89c4ff92SAndroid Build Coastguard Worker         std::copy(buffer.begin(), buffer.end(),std::ostream_iterator<int>(strStream,","));
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker         return strStream.str();
257*89c4ff92SAndroid Build Coastguard Worker     }
258*89c4ff92SAndroid Build Coastguard Worker 
CheckTensorsParserFlatbuffersFixture259*89c4ff92SAndroid Build Coastguard Worker     void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape,
260*89c4ff92SAndroid Build Coastguard Worker                       tflite::TensorType tensorType, uint32_t buffer, const std::string& name,
261*89c4ff92SAndroid Build Coastguard Worker                       const std::vector<float>& min, const std::vector<float>& max,
262*89c4ff92SAndroid Build Coastguard Worker                       const std::vector<float>& scale, const std::vector<int64_t>& zeroPoint)
263*89c4ff92SAndroid Build Coastguard Worker     {
264*89c4ff92SAndroid Build Coastguard Worker         CHECK(tensors);
265*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(shapeSize, tensors->shape.size());
266*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(shape.begin(), shape.end(), tensors->shape.begin(), tensors->shape.end()));
267*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(tensorType, tensors->type);
268*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(buffer, tensors->buffer);
269*89c4ff92SAndroid Build Coastguard Worker         CHECK_EQ(name, tensors->name);
270*89c4ff92SAndroid Build Coastguard Worker         CHECK(tensors->quantization);
271*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(min.begin(), min.end(), tensors->quantization.get()->min.begin(),
272*89c4ff92SAndroid Build Coastguard Worker                                       tensors->quantization.get()->min.end()));
273*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(max.begin(), max.end(), tensors->quantization.get()->max.begin(),
274*89c4ff92SAndroid Build Coastguard Worker                                       tensors->quantization.get()->max.end()));
275*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(scale.begin(), scale.end(), tensors->quantization.get()->scale.begin(),
276*89c4ff92SAndroid Build Coastguard Worker                                       tensors->quantization.get()->scale.end()));
277*89c4ff92SAndroid Build Coastguard Worker         CHECK(std::equal(zeroPoint.begin(), zeroPoint.end(),
278*89c4ff92SAndroid Build Coastguard Worker                                       tensors->quantization.get()->zero_point.begin(),
279*89c4ff92SAndroid Build Coastguard Worker                                       tensors->quantization.get()->zero_point.end()));
280*89c4ff92SAndroid Build Coastguard Worker     }
281*89c4ff92SAndroid Build Coastguard Worker 
282*89c4ff92SAndroid Build Coastguard Worker private:
283*89c4ff92SAndroid Build Coastguard Worker     /// Fills the InputTensors with given input data
284*89c4ff92SAndroid Build Coastguard Worker     template <armnn::DataType dataType>
285*89c4ff92SAndroid Build Coastguard Worker     void FillInputTensors(armnn::InputTensors& inputTensors,
286*89c4ff92SAndroid Build Coastguard Worker                           const std::map<std::string, std::vector<armnn::ResolveType<dataType>>>& inputData,
287*89c4ff92SAndroid Build Coastguard Worker                           size_t subgraphId);
288*89c4ff92SAndroid Build Coastguard Worker };
289*89c4ff92SAndroid Build Coastguard Worker 
290*89c4ff92SAndroid Build Coastguard Worker /// Fills the InputTensors with given input data
291*89c4ff92SAndroid Build Coastguard Worker template <armnn::DataType dataType>
FillInputTensors(armnn::InputTensors & inputTensors,const std::map<std::string,std::vector<armnn::ResolveType<dataType>>> & inputData,size_t subgraphId)292*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersFixture::FillInputTensors(
293*89c4ff92SAndroid Build Coastguard Worker                   armnn::InputTensors& inputTensors,
294*89c4ff92SAndroid Build Coastguard Worker                   const std::map<std::string, std::vector<armnn::ResolveType<dataType>>>& inputData,
295*89c4ff92SAndroid Build Coastguard Worker                   size_t subgraphId)
296*89c4ff92SAndroid Build Coastguard Worker {
297*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : inputData)
298*89c4ff92SAndroid Build Coastguard Worker     {
299*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkInputBindingInfo(subgraphId, it.first);
300*89c4ff92SAndroid Build Coastguard Worker         bindingInfo.second.SetConstant(true);
301*89c4ff92SAndroid Build Coastguard Worker         armnn::VerifyTensorInfoDataType(bindingInfo.second, dataType);
302*89c4ff92SAndroid Build Coastguard Worker         inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
303*89c4ff92SAndroid Build Coastguard Worker     }
304*89c4ff92SAndroid Build Coastguard Worker }
305*89c4ff92SAndroid Build Coastguard Worker 
306*89c4ff92SAndroid Build Coastguard Worker /// Single Input, Single Output
307*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensor and checks the result against the given output tensor.
308*89c4ff92SAndroid Build Coastguard Worker /// This overload assumes the network has a single input and a single output.
309*89c4ff92SAndroid Build Coastguard Worker template <std::size_t NumOutputDimensions,
310*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType armnnType>
RunTest(size_t subgraphId,const std::vector<armnn::ResolveType<armnnType>> & inputData,const std::vector<armnn::ResolveType<armnnType>> & expectedOutputData)311*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
312*89c4ff92SAndroid Build Coastguard Worker                                        const std::vector<armnn::ResolveType<armnnType>>& inputData,
313*89c4ff92SAndroid Build Coastguard Worker                                        const std::vector<armnn::ResolveType<armnnType>>& expectedOutputData)
314*89c4ff92SAndroid Build Coastguard Worker {
315*89c4ff92SAndroid Build Coastguard Worker     RunTest<NumOutputDimensions, armnnType>(subgraphId,
316*89c4ff92SAndroid Build Coastguard Worker                                             { { m_SingleInputName, inputData } },
317*89c4ff92SAndroid Build Coastguard Worker                                             { { m_SingleOutputName, expectedOutputData } });
318*89c4ff92SAndroid Build Coastguard Worker }
319*89c4ff92SAndroid Build Coastguard Worker 
320*89c4ff92SAndroid Build Coastguard Worker /// Multiple Inputs, Multiple Outputs
321*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensors and checks the results against the given output tensors.
322*89c4ff92SAndroid Build Coastguard Worker /// This overload supports multiple inputs and multiple outputs, identified by name.
323*89c4ff92SAndroid Build Coastguard Worker template <std::size_t NumOutputDimensions,
324*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType armnnType>
RunTest(size_t subgraphId,const std::map<std::string,std::vector<armnn::ResolveType<armnnType>>> & inputData,const std::map<std::string,std::vector<armnn::ResolveType<armnnType>>> & expectedOutputData)325*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
326*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<armnnType>>>& inputData,
327*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<armnnType>>>& expectedOutputData)
328*89c4ff92SAndroid Build Coastguard Worker {
329*89c4ff92SAndroid Build Coastguard Worker     RunTest<NumOutputDimensions, armnnType, armnnType>(subgraphId, inputData, expectedOutputData);
330*89c4ff92SAndroid Build Coastguard Worker }
331*89c4ff92SAndroid Build Coastguard Worker 
332*89c4ff92SAndroid Build Coastguard Worker /// Multiple Inputs, Multiple Outputs w/ Variable Datatypes
333*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensors and checks the results against the given output tensors.
334*89c4ff92SAndroid Build Coastguard Worker /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
335*89c4ff92SAndroid Build Coastguard Worker /// the input datatype to be different to the output
336*89c4ff92SAndroid Build Coastguard Worker template <std::size_t NumOutputDimensions,
337*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType armnnType1,
338*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType armnnType2>
RunTest(size_t subgraphId,const std::map<std::string,std::vector<armnn::ResolveType<armnnType1>>> & inputData,const std::map<std::string,std::vector<armnn::ResolveType<armnnType2>>> & expectedOutputData,bool isDynamic)339*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
340*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<armnnType1>>>& inputData,
341*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<armnnType2>>>& expectedOutputData,
342*89c4ff92SAndroid Build Coastguard Worker     bool isDynamic)
343*89c4ff92SAndroid Build Coastguard Worker {
344*89c4ff92SAndroid Build Coastguard Worker     using DataType2 = armnn::ResolveType<armnnType2>;
345*89c4ff92SAndroid Build Coastguard Worker 
346*89c4ff92SAndroid Build Coastguard Worker     // Setup the armnn input tensors from the given vectors.
347*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
348*89c4ff92SAndroid Build Coastguard Worker     FillInputTensors<armnnType1>(inputTensors, inputData, subgraphId);
349*89c4ff92SAndroid Build Coastguard Worker 
350*89c4ff92SAndroid Build Coastguard Worker     // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
351*89c4ff92SAndroid Build Coastguard Worker     std::map<std::string, std::vector<DataType2>> outputStorage;
352*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
353*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
354*89c4ff92SAndroid Build Coastguard Worker     {
355*89c4ff92SAndroid Build Coastguard Worker         armnn::LayerBindingId outputBindingId = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first).first;
356*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkIdentifier, outputBindingId);
357*89c4ff92SAndroid Build Coastguard Worker 
358*89c4ff92SAndroid Build Coastguard Worker         // Check that output tensors have correct number of dimensions (NumOutputDimensions specified in test)
359*89c4ff92SAndroid Build Coastguard Worker         auto outputNumDimensions = outputTensorInfo.GetNumDimensions();
360*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE((outputNumDimensions == NumOutputDimensions),
361*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Number of dimensions expected {}, but got {} for output layer {}",
362*89c4ff92SAndroid Build Coastguard Worker                         NumOutputDimensions,
363*89c4ff92SAndroid Build Coastguard Worker                         outputNumDimensions,
364*89c4ff92SAndroid Build Coastguard Worker                         it.first));
365*89c4ff92SAndroid Build Coastguard Worker 
366*89c4ff92SAndroid Build Coastguard Worker         armnn::VerifyTensorInfoDataType(outputTensorInfo, armnnType2);
367*89c4ff92SAndroid Build Coastguard Worker         outputStorage.emplace(it.first, std::vector<DataType2>(outputTensorInfo.GetNumElements()));
368*89c4ff92SAndroid Build Coastguard Worker         outputTensors.push_back(
369*89c4ff92SAndroid Build Coastguard Worker                 { outputBindingId, armnn::Tensor(outputTensorInfo, outputStorage.at(it.first).data()) });
370*89c4ff92SAndroid Build Coastguard Worker     }
371*89c4ff92SAndroid Build Coastguard Worker 
372*89c4ff92SAndroid Build Coastguard Worker     m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
373*89c4ff92SAndroid Build Coastguard Worker 
374*89c4ff92SAndroid Build Coastguard Worker     // Set flag so that the correct comparison function is called if the output is boolean.
375*89c4ff92SAndroid Build Coastguard Worker     bool isBoolean = armnnType2 == armnn::DataType::Boolean ? true : false;
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker     // Compare each output tensor to the expected values
378*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
379*89c4ff92SAndroid Build Coastguard Worker     {
380*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
381*89c4ff92SAndroid Build Coastguard Worker         auto outputExpected = it.second;
382*89c4ff92SAndroid Build Coastguard Worker         auto result = CompareTensors(outputExpected, outputStorage[it.first],
383*89c4ff92SAndroid Build Coastguard Worker                                      bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
384*89c4ff92SAndroid Build Coastguard Worker                                      isBoolean, isDynamic);
385*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(result.m_Result, result.m_Message.str());
386*89c4ff92SAndroid Build Coastguard Worker     }
387*89c4ff92SAndroid Build Coastguard Worker 
388*89c4ff92SAndroid Build Coastguard Worker     if (isDynamic)
389*89c4ff92SAndroid Build Coastguard Worker     {
390*89c4ff92SAndroid Build Coastguard Worker         m_Runtime->EnqueueWorkload(m_DynamicNetworkIdentifier, inputTensors, outputTensors);
391*89c4ff92SAndroid Build Coastguard Worker 
392*89c4ff92SAndroid Build Coastguard Worker         // Compare each output tensor to the expected values
393*89c4ff92SAndroid Build Coastguard Worker         for (auto&& it : expectedOutputData)
394*89c4ff92SAndroid Build Coastguard Worker         {
395*89c4ff92SAndroid Build Coastguard Worker             armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
396*89c4ff92SAndroid Build Coastguard Worker             auto outputExpected = it.second;
397*89c4ff92SAndroid Build Coastguard Worker             auto result = CompareTensors(outputExpected, outputStorage[it.first],
398*89c4ff92SAndroid Build Coastguard Worker                                          bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
399*89c4ff92SAndroid Build Coastguard Worker                                          false, isDynamic);
400*89c4ff92SAndroid Build Coastguard Worker             CHECK_MESSAGE(result.m_Result, result.m_Message.str());
401*89c4ff92SAndroid Build Coastguard Worker         }
402*89c4ff92SAndroid Build Coastguard Worker     }
403*89c4ff92SAndroid Build Coastguard Worker }
404*89c4ff92SAndroid Build Coastguard Worker 
405*89c4ff92SAndroid Build Coastguard Worker /// Multiple Inputs, Multiple Outputs w/ Variable Datatypes and different dimension sizes.
406*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensors and checks the results against the given output tensors.
407*89c4ff92SAndroid Build Coastguard Worker /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
408*89c4ff92SAndroid Build Coastguard Worker /// the input datatype to be different to the output.
409*89c4ff92SAndroid Build Coastguard Worker template <armnn::DataType armnnType1,
410*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType armnnType2>
RunTest(std::size_t subgraphId,const std::map<std::string,std::vector<armnn::ResolveType<armnnType1>>> & inputData,const std::map<std::string,std::vector<armnn::ResolveType<armnnType2>>> & expectedOutputData)411*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersFixture::RunTest(std::size_t subgraphId,
412*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<armnnType1>>>& inputData,
413*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<armnnType2>>>& expectedOutputData)
414*89c4ff92SAndroid Build Coastguard Worker {
415*89c4ff92SAndroid Build Coastguard Worker     using DataType2 = armnn::ResolveType<armnnType2>;
416*89c4ff92SAndroid Build Coastguard Worker 
417*89c4ff92SAndroid Build Coastguard Worker     // Setup the armnn input tensors from the given vectors.
418*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
419*89c4ff92SAndroid Build Coastguard Worker     FillInputTensors<armnnType1>(inputTensors, inputData, subgraphId);
420*89c4ff92SAndroid Build Coastguard Worker 
421*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
422*89c4ff92SAndroid Build Coastguard Worker     outputTensors.reserve(expectedOutputData.size());
423*89c4ff92SAndroid Build Coastguard Worker     std::map<std::string, std::vector<DataType2>> outputStorage;
424*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
425*89c4ff92SAndroid Build Coastguard Worker     {
426*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
427*89c4ff92SAndroid Build Coastguard Worker         armnn::VerifyTensorInfoDataType(bindingInfo.second, armnnType2);
428*89c4ff92SAndroid Build Coastguard Worker 
429*89c4ff92SAndroid Build Coastguard Worker         std::vector<DataType2> out(it.second.size());
430*89c4ff92SAndroid Build Coastguard Worker         outputStorage.emplace(it.first, out);
431*89c4ff92SAndroid Build Coastguard Worker         outputTensors.push_back({ bindingInfo.first,
432*89c4ff92SAndroid Build Coastguard Worker                                   armnn::Tensor(bindingInfo.second,
433*89c4ff92SAndroid Build Coastguard Worker                                   outputStorage.at(it.first).data()) });
434*89c4ff92SAndroid Build Coastguard Worker     }
435*89c4ff92SAndroid Build Coastguard Worker 
436*89c4ff92SAndroid Build Coastguard Worker     m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
437*89c4ff92SAndroid Build Coastguard Worker 
438*89c4ff92SAndroid Build Coastguard Worker     // Checks the results.
439*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
440*89c4ff92SAndroid Build Coastguard Worker     {
441*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::ResolveType<armnnType2>> out = outputStorage.at(it.first);
442*89c4ff92SAndroid Build Coastguard Worker         {
443*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int i = 0; i < out.size(); ++i)
444*89c4ff92SAndroid Build Coastguard Worker             {
445*89c4ff92SAndroid Build Coastguard Worker                 CHECK(doctest::Approx(it.second[i]).epsilon(0.000001f) == out[i]);
446*89c4ff92SAndroid Build Coastguard Worker             }
447*89c4ff92SAndroid Build Coastguard Worker         }
448*89c4ff92SAndroid Build Coastguard Worker     }
449*89c4ff92SAndroid Build Coastguard Worker }
450*89c4ff92SAndroid Build Coastguard Worker 
451*89c4ff92SAndroid Build Coastguard Worker /// Multiple Inputs with different DataTypes, Multiple Outputs w/ Variable DataTypes
452*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensors and checks the results against the given output tensors.
453*89c4ff92SAndroid Build Coastguard Worker /// This overload supports multiple inputs and multiple outputs, identified by name along with the allowance for
454*89c4ff92SAndroid Build Coastguard Worker /// the input datatype to be different to the output
455*89c4ff92SAndroid Build Coastguard Worker template <std::size_t NumOutputDimensions,
456*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType inputType1,
457*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType inputType2,
458*89c4ff92SAndroid Build Coastguard Worker           armnn::DataType outputType>
RunTest(size_t subgraphId,const std::map<std::string,std::vector<armnn::ResolveType<inputType1>>> & input1Data,const std::map<std::string,std::vector<armnn::ResolveType<inputType2>>> & input2Data,const std::map<std::string,std::vector<armnn::ResolveType<outputType>>> & expectedOutputData)459*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersFixture::RunTest(size_t subgraphId,
460*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<inputType1>>>& input1Data,
461*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<inputType2>>>& input2Data,
462*89c4ff92SAndroid Build Coastguard Worker     const std::map<std::string, std::vector<armnn::ResolveType<outputType>>>& expectedOutputData)
463*89c4ff92SAndroid Build Coastguard Worker {
464*89c4ff92SAndroid Build Coastguard Worker     using DataType2 = armnn::ResolveType<outputType>;
465*89c4ff92SAndroid Build Coastguard Worker 
466*89c4ff92SAndroid Build Coastguard Worker     // Setup the armnn input tensors from the given vectors.
467*89c4ff92SAndroid Build Coastguard Worker     armnn::InputTensors inputTensors;
468*89c4ff92SAndroid Build Coastguard Worker     FillInputTensors<inputType1>(inputTensors, input1Data, subgraphId);
469*89c4ff92SAndroid Build Coastguard Worker     FillInputTensors<inputType2>(inputTensors, input2Data, subgraphId);
470*89c4ff92SAndroid Build Coastguard Worker 
471*89c4ff92SAndroid Build Coastguard Worker     // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
472*89c4ff92SAndroid Build Coastguard Worker     std::map<std::string, std::vector<DataType2>> outputStorage;
473*89c4ff92SAndroid Build Coastguard Worker     armnn::OutputTensors outputTensors;
474*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
475*89c4ff92SAndroid Build Coastguard Worker     {
476*89c4ff92SAndroid Build Coastguard Worker         armnn::LayerBindingId outputBindingId = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first).first;
477*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkIdentifier, outputBindingId);
478*89c4ff92SAndroid Build Coastguard Worker 
479*89c4ff92SAndroid Build Coastguard Worker         // Check that output tensors have correct number of dimensions (NumOutputDimensions specified in test)
480*89c4ff92SAndroid Build Coastguard Worker         auto outputNumDimensions = outputTensorInfo.GetNumDimensions();
481*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE((outputNumDimensions == NumOutputDimensions),
482*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Number of dimensions expected {}, but got {} for output layer {}",
483*89c4ff92SAndroid Build Coastguard Worker                         NumOutputDimensions,
484*89c4ff92SAndroid Build Coastguard Worker                         outputNumDimensions,
485*89c4ff92SAndroid Build Coastguard Worker                         it.first));
486*89c4ff92SAndroid Build Coastguard Worker 
487*89c4ff92SAndroid Build Coastguard Worker         armnn::VerifyTensorInfoDataType(outputTensorInfo, outputType);
488*89c4ff92SAndroid Build Coastguard Worker         outputStorage.emplace(it.first, std::vector<DataType2>(outputTensorInfo.GetNumElements()));
489*89c4ff92SAndroid Build Coastguard Worker         outputTensors.push_back(
490*89c4ff92SAndroid Build Coastguard Worker                 { outputBindingId, armnn::Tensor(outputTensorInfo, outputStorage.at(it.first).data()) });
491*89c4ff92SAndroid Build Coastguard Worker     }
492*89c4ff92SAndroid Build Coastguard Worker 
493*89c4ff92SAndroid Build Coastguard Worker     m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
494*89c4ff92SAndroid Build Coastguard Worker 
495*89c4ff92SAndroid Build Coastguard Worker     // Set flag so that the correct comparison function is called if the output is boolean.
496*89c4ff92SAndroid Build Coastguard Worker     bool isBoolean = outputType == armnn::DataType::Boolean ? true : false;
497*89c4ff92SAndroid Build Coastguard Worker 
498*89c4ff92SAndroid Build Coastguard Worker     // Compare each output tensor to the expected values
499*89c4ff92SAndroid Build Coastguard Worker     for (auto&& it : expectedOutputData)
500*89c4ff92SAndroid Build Coastguard Worker     {
501*89c4ff92SAndroid Build Coastguard Worker         armnn::BindingPointInfo bindingInfo = m_Parser->GetNetworkOutputBindingInfo(subgraphId, it.first);
502*89c4ff92SAndroid Build Coastguard Worker         auto outputExpected = it.second;
503*89c4ff92SAndroid Build Coastguard Worker         auto result = CompareTensors(outputExpected, outputStorage[it.first],
504*89c4ff92SAndroid Build Coastguard Worker                                      bindingInfo.second.GetShape(), bindingInfo.second.GetShape(),
505*89c4ff92SAndroid Build Coastguard Worker                                      isBoolean);
506*89c4ff92SAndroid Build Coastguard Worker         CHECK_MESSAGE(result.m_Result, result.m_Message.str());
507*89c4ff92SAndroid Build Coastguard Worker     }
508*89c4ff92SAndroid Build Coastguard Worker }