1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
#include "SchemaSerialize.hpp"
#include <armnnTestUtils/TensorHelpers.hpp>

#include "flatbuffers/idl.h"
#include "flatbuffers/util.h"

#include <ArmnnSchema_generated.h>
#include <armnn/IRuntime.hpp>
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/utility/Assert.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <ResolveType.hpp>

#include <fmt/format.h>
#include <doctest/doctest.h>

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker using armnnDeserializer::IDeserializer;
27*89c4ff92SAndroid Build Coastguard Worker using TensorRawPtr = armnnSerializer::TensorInfo*;
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker struct ParserFlatbuffersSerializeFixture
30*89c4ff92SAndroid Build Coastguard Worker {
ParserFlatbuffersSerializeFixtureParserFlatbuffersSerializeFixture31*89c4ff92SAndroid Build Coastguard Worker ParserFlatbuffersSerializeFixture() :
32*89c4ff92SAndroid Build Coastguard Worker m_Parser(IDeserializer::Create()),
33*89c4ff92SAndroid Build Coastguard Worker m_Runtime(armnn::IRuntime::Create(armnn::IRuntime::CreationOptions())),
34*89c4ff92SAndroid Build Coastguard Worker m_NetworkIdentifier(-1)
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker }
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_GraphBinary;
39*89c4ff92SAndroid Build Coastguard Worker std::string m_JsonString;
40*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IDeserializer, void (*)(IDeserializer* parser)> m_Parser;
41*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntimePtr m_Runtime;
42*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId m_NetworkIdentifier;
43*89c4ff92SAndroid Build Coastguard Worker
44*89c4ff92SAndroid Build Coastguard Worker /// If the single-input-single-output overload of Setup() is called, these will store the input and output name
45*89c4ff92SAndroid Build Coastguard Worker /// so they don't need to be passed to the single-input-single-output overload of RunTest().
46*89c4ff92SAndroid Build Coastguard Worker std::string m_SingleInputName;
47*89c4ff92SAndroid Build Coastguard Worker std::string m_SingleOutputName;
48*89c4ff92SAndroid Build Coastguard Worker
SetupParserFlatbuffersSerializeFixture49*89c4ff92SAndroid Build Coastguard Worker void Setup()
50*89c4ff92SAndroid Build Coastguard Worker {
51*89c4ff92SAndroid Build Coastguard Worker bool ok = ReadStringToBinary();
52*89c4ff92SAndroid Build Coastguard Worker if (!ok)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("LoadNetwork failed while reading binary input");
55*89c4ff92SAndroid Build Coastguard Worker }
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network =
58*89c4ff92SAndroid Build Coastguard Worker m_Parser->CreateNetworkFromBinary(m_GraphBinary);
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker if (!network)
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception("The parser failed to create an ArmNN network");
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker auto optimized = Optimize(*network, {armnn::Compute::CpuRef},
66*89c4ff92SAndroid Build Coastguard Worker m_Runtime->GetDeviceSpec());
67*89c4ff92SAndroid Build Coastguard Worker
68*89c4ff92SAndroid Build Coastguard Worker std::string errorMessage;
69*89c4ff92SAndroid Build Coastguard Worker armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, move(optimized), errorMessage);
70*89c4ff92SAndroid Build Coastguard Worker
71*89c4ff92SAndroid Build Coastguard Worker if (ret != armnn::Status::Success)
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker throw armnn::Exception(fmt::format("The runtime failed to load the network. "
74*89c4ff92SAndroid Build Coastguard Worker "Error was: {0}. in {1} [{2}:{3}]",
75*89c4ff92SAndroid Build Coastguard Worker errorMessage,
76*89c4ff92SAndroid Build Coastguard Worker __func__,
77*89c4ff92SAndroid Build Coastguard Worker __FILE__,
78*89c4ff92SAndroid Build Coastguard Worker __LINE__));
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker
SetupSingleInputSingleOutputParserFlatbuffersSerializeFixture83*89c4ff92SAndroid Build Coastguard Worker void SetupSingleInputSingleOutput(const std::string& inputName, const std::string& outputName)
84*89c4ff92SAndroid Build Coastguard Worker {
85*89c4ff92SAndroid Build Coastguard Worker // Store the input and output name so they don't need to be passed to the single-input-single-output RunTest().
86*89c4ff92SAndroid Build Coastguard Worker m_SingleInputName = inputName;
87*89c4ff92SAndroid Build Coastguard Worker m_SingleOutputName = outputName;
88*89c4ff92SAndroid Build Coastguard Worker Setup();
89*89c4ff92SAndroid Build Coastguard Worker }
90*89c4ff92SAndroid Build Coastguard Worker
ReadStringToBinaryParserFlatbuffersSerializeFixture91*89c4ff92SAndroid Build Coastguard Worker bool ReadStringToBinary()
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker std::string schemafile(&deserialize_schema_start, &deserialize_schema_end);
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker // parse schema first, so we can use it to parse the data after
96*89c4ff92SAndroid Build Coastguard Worker flatbuffers::Parser parser;
97*89c4ff92SAndroid Build Coastguard Worker
98*89c4ff92SAndroid Build Coastguard Worker bool ok = parser.Parse(schemafile.c_str());
99*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(ok, std::string("Failed to parse schema file. Error was: " + parser.error_).c_str());
100*89c4ff92SAndroid Build Coastguard Worker
101*89c4ff92SAndroid Build Coastguard Worker ok &= parser.Parse(m_JsonString.c_str());
102*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(ok, std::string("Failed to parse json input. Error was: " + parser.error_).c_str());
103*89c4ff92SAndroid Build Coastguard Worker
104*89c4ff92SAndroid Build Coastguard Worker if (!ok)
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker return false;
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker const uint8_t* bufferPtr = parser.builder_.GetBufferPointer();
111*89c4ff92SAndroid Build Coastguard Worker size_t size = static_cast<size_t>(parser.builder_.GetSize());
112*89c4ff92SAndroid Build Coastguard Worker m_GraphBinary.assign(bufferPtr, bufferPtr+size);
113*89c4ff92SAndroid Build Coastguard Worker }
114*89c4ff92SAndroid Build Coastguard Worker return ok;
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensor and checks the result against the given output tensor.
118*89c4ff92SAndroid Build Coastguard Worker /// This overload assumes the network has a single input and a single output.
119*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions,
120*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnType,
121*89c4ff92SAndroid Build Coastguard Worker typename DataType = armnn::ResolveType<ArmnnType>>
122*89c4ff92SAndroid Build Coastguard Worker void RunTest(unsigned int layersId,
123*89c4ff92SAndroid Build Coastguard Worker const std::vector<DataType>& inputData,
124*89c4ff92SAndroid Build Coastguard Worker const std::vector<DataType>& expectedOutputData);
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions,
127*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnInputType,
128*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnOutputType,
129*89c4ff92SAndroid Build Coastguard Worker typename InputDataType = armnn::ResolveType<ArmnnInputType>,
130*89c4ff92SAndroid Build Coastguard Worker typename OutputDataType = armnn::ResolveType<ArmnnOutputType>>
131*89c4ff92SAndroid Build Coastguard Worker void RunTest(unsigned int layersId,
132*89c4ff92SAndroid Build Coastguard Worker const std::vector<InputDataType>& inputData,
133*89c4ff92SAndroid Build Coastguard Worker const std::vector<OutputDataType>& expectedOutputData);
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker /// Executes the network with the given input tensors and checks the results against the given output tensors.
136*89c4ff92SAndroid Build Coastguard Worker /// This overload supports multiple inputs and multiple outputs, identified by name.
137*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions,
138*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnType,
139*89c4ff92SAndroid Build Coastguard Worker typename DataType = armnn::ResolveType<ArmnnType>>
140*89c4ff92SAndroid Build Coastguard Worker void RunTest(unsigned int layersId,
141*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<DataType>>& inputData,
142*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<DataType>>& expectedOutputData);
143*89c4ff92SAndroid Build Coastguard Worker
144*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions,
145*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnInputType,
146*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnOutputType,
147*89c4ff92SAndroid Build Coastguard Worker typename InputDataType = armnn::ResolveType<ArmnnInputType>,
148*89c4ff92SAndroid Build Coastguard Worker typename OutputDataType = armnn::ResolveType<ArmnnOutputType>>
149*89c4ff92SAndroid Build Coastguard Worker void RunTest(unsigned int layersId,
150*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<InputDataType>>& inputData,
151*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData);
152*89c4ff92SAndroid Build Coastguard Worker
CheckTensorsParserFlatbuffersSerializeFixture153*89c4ff92SAndroid Build Coastguard Worker void CheckTensors(const TensorRawPtr& tensors, size_t shapeSize, const std::vector<int32_t>& shape,
154*89c4ff92SAndroid Build Coastguard Worker armnnSerializer::TensorInfo tensorType, const std::string& name,
155*89c4ff92SAndroid Build Coastguard Worker const float scale, const int64_t zeroPoint)
156*89c4ff92SAndroid Build Coastguard Worker {
157*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(name);
158*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(shapeSize, tensors->dimensions()->size());
159*89c4ff92SAndroid Build Coastguard Worker CHECK(std::equal(shape.begin(), shape.end(),
160*89c4ff92SAndroid Build Coastguard Worker tensors->dimensions()->begin(), tensors->dimensions()->end()));
161*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(tensorType.dataType(), tensors->dataType());
162*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(scale, tensors->quantizationScale());
163*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(zeroPoint, tensors->quantizationOffset());
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker };
166*89c4ff92SAndroid Build Coastguard Worker
167*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions, armnn::DataType ArmnnType, typename DataType>
RunTest(unsigned int layersId,const std::vector<DataType> & inputData,const std::vector<DataType> & expectedOutputData)168*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersSerializeFixture::RunTest(unsigned int layersId,
169*89c4ff92SAndroid Build Coastguard Worker const std::vector<DataType>& inputData,
170*89c4ff92SAndroid Build Coastguard Worker const std::vector<DataType>& expectedOutputData)
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker RunTest<NumOutputDimensions, ArmnnType, ArmnnType, DataType, DataType>(layersId, inputData, expectedOutputData);
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker
175*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions,
176*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnInputType,
177*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnOutputType,
178*89c4ff92SAndroid Build Coastguard Worker typename InputDataType,
179*89c4ff92SAndroid Build Coastguard Worker typename OutputDataType>
RunTest(unsigned int layersId,const std::vector<InputDataType> & inputData,const std::vector<OutputDataType> & expectedOutputData)180*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersSerializeFixture::RunTest(unsigned int layersId,
181*89c4ff92SAndroid Build Coastguard Worker const std::vector<InputDataType>& inputData,
182*89c4ff92SAndroid Build Coastguard Worker const std::vector<OutputDataType>& expectedOutputData)
183*89c4ff92SAndroid Build Coastguard Worker {
184*89c4ff92SAndroid Build Coastguard Worker RunTest<NumOutputDimensions, ArmnnInputType, ArmnnOutputType>(layersId,
185*89c4ff92SAndroid Build Coastguard Worker { { m_SingleInputName, inputData } },
186*89c4ff92SAndroid Build Coastguard Worker { { m_SingleOutputName, expectedOutputData } });
187*89c4ff92SAndroid Build Coastguard Worker }
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions, armnn::DataType ArmnnType, typename DataType>
RunTest(unsigned int layersId,const std::map<std::string,std::vector<DataType>> & inputData,const std::map<std::string,std::vector<DataType>> & expectedOutputData)190*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersSerializeFixture::RunTest(unsigned int layersId,
191*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<DataType>>& inputData,
192*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<DataType>>& expectedOutputData)
193*89c4ff92SAndroid Build Coastguard Worker {
194*89c4ff92SAndroid Build Coastguard Worker RunTest<NumOutputDimensions, ArmnnType, ArmnnType, DataType, DataType>(layersId, inputData, expectedOutputData);
195*89c4ff92SAndroid Build Coastguard Worker }
196*89c4ff92SAndroid Build Coastguard Worker
197*89c4ff92SAndroid Build Coastguard Worker template<std::size_t NumOutputDimensions,
198*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnInputType,
199*89c4ff92SAndroid Build Coastguard Worker armnn::DataType ArmnnOutputType,
200*89c4ff92SAndroid Build Coastguard Worker typename InputDataType,
201*89c4ff92SAndroid Build Coastguard Worker typename OutputDataType>
RunTest(unsigned int layersId,const std::map<std::string,std::vector<InputDataType>> & inputData,const std::map<std::string,std::vector<OutputDataType>> & expectedOutputData)202*89c4ff92SAndroid Build Coastguard Worker void ParserFlatbuffersSerializeFixture::RunTest(
203*89c4ff92SAndroid Build Coastguard Worker unsigned int layersId,
204*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<InputDataType>>& inputData,
205*89c4ff92SAndroid Build Coastguard Worker const std::map<std::string, std::vector<OutputDataType>>& expectedOutputData)
206*89c4ff92SAndroid Build Coastguard Worker {
207*89c4ff92SAndroid Build Coastguard Worker auto ConvertBindingInfo = [](const armnnDeserializer::BindingPointInfo& bindingInfo)
208*89c4ff92SAndroid Build Coastguard Worker {
209*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(bindingInfo.m_BindingId, bindingInfo.m_TensorInfo);
210*89c4ff92SAndroid Build Coastguard Worker };
211*89c4ff92SAndroid Build Coastguard Worker
212*89c4ff92SAndroid Build Coastguard Worker // Setup the armnn input tensors from the given vectors.
213*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors;
214*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : inputData)
215*89c4ff92SAndroid Build Coastguard Worker {
216*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
217*89c4ff92SAndroid Build Coastguard Worker m_Parser->GetNetworkInputBindingInfo(layersId, it.first));
218*89c4ff92SAndroid Build Coastguard Worker bindingInfo.second.SetConstant(true);
219*89c4ff92SAndroid Build Coastguard Worker armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnInputType);
220*89c4ff92SAndroid Build Coastguard Worker inputTensors.push_back({ bindingInfo.first, armnn::ConstTensor(bindingInfo.second, it.second.data()) });
221*89c4ff92SAndroid Build Coastguard Worker }
222*89c4ff92SAndroid Build Coastguard Worker
223*89c4ff92SAndroid Build Coastguard Worker // Allocate storage for the output tensors to be written to and setup the armnn output tensors.
224*89c4ff92SAndroid Build Coastguard Worker std::map<std::string, std::vector<OutputDataType>> outputStorage;
225*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors;
226*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : expectedOutputData)
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
229*89c4ff92SAndroid Build Coastguard Worker m_Parser->GetNetworkOutputBindingInfo(layersId, it.first));
230*89c4ff92SAndroid Build Coastguard Worker armnn::VerifyTensorInfoDataType(bindingInfo.second, ArmnnOutputType);
231*89c4ff92SAndroid Build Coastguard Worker outputStorage.emplace(it.first, std::vector<OutputDataType>(bindingInfo.second.GetNumElements()));
232*89c4ff92SAndroid Build Coastguard Worker outputTensors.push_back(
233*89c4ff92SAndroid Build Coastguard Worker { bindingInfo.first, armnn::Tensor(bindingInfo.second, outputStorage.at(it.first).data()) });
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker
236*89c4ff92SAndroid Build Coastguard Worker m_Runtime->EnqueueWorkload(m_NetworkIdentifier, inputTensors, outputTensors);
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker // Compare each output tensor to the expected values
239*89c4ff92SAndroid Build Coastguard Worker for (auto&& it : expectedOutputData)
240*89c4ff92SAndroid Build Coastguard Worker {
241*89c4ff92SAndroid Build Coastguard Worker armnn::BindingPointInfo bindingInfo = ConvertBindingInfo(
242*89c4ff92SAndroid Build Coastguard Worker m_Parser->GetNetworkOutputBindingInfo(layersId, it.first));
243*89c4ff92SAndroid Build Coastguard Worker auto outputExpected = it.second;
244*89c4ff92SAndroid Build Coastguard Worker auto result = CompareTensors(outputExpected, outputStorage[it.first],
245*89c4ff92SAndroid Build Coastguard Worker bindingInfo.second.GetShape(), bindingInfo.second.GetShape());
246*89c4ff92SAndroid Build Coastguard Worker CHECK_MESSAGE(result.m_Result, result.m_Message.str());
247*89c4ff92SAndroid Build Coastguard Worker }
248*89c4ff92SAndroid Build Coastguard Worker }
249