xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/TfLiteParser.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "TfLiteParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "armnnTfLiteParser/Version.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "armnn/LstmParams.hpp"
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendOptions.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Logging.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/TensorUtils.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/TypesUtils.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/LayerSupport.hpp>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker // armnnUtils:
24*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
25*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Filesystem.hpp>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker #include <ParserHelper.hpp>
28*89c4ff92SAndroid Build Coastguard Worker #include <VerificationHelpers.hpp>
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker // The generated code based on the Tf Lite schema:
31*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flexbuffers.h>
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
38*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
39*89c4ff92SAndroid Build Coastguard Worker #include <limits>
40*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
41*89c4ff92SAndroid Build Coastguard Worker 
// Throws armnn::ParseException with the text of `msg`, followed by ": " and the location
// of the throw site. `msg` is inserted into a std::stringstream unparenthesized, so it may
// be a stream-insertion chain, e.g. ARMNN_THROW_PARSE_EXCEPTION("bad value " << value);
// The expansion is a plain brace block (not do/while), so existing call sites keep compiling.
#define ARMNN_THROW_PARSE_EXCEPTION(msg) \
          { \
            std::stringstream armnnParseExceptionMessage_; \
            armnnParseExceptionMessage_ << msg \
               << ": " \
               << CHECK_LOCATION().AsString(); \
            throw armnn::ParseException(armnnParseExceptionMessage_.str()); \
          }
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
50*89c4ff92SAndroid Build Coastguard Worker using armnn::CheckLocation;
51*89c4ff92SAndroid Build Coastguard Worker namespace armnnTfLiteParser
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker 
ITfLiteParser(const armnn::Optional<TfLiteParserOptions> & options)54*89c4ff92SAndroid Build Coastguard Worker ITfLiteParser::ITfLiteParser(const armnn::Optional<TfLiteParserOptions>& options) :
55*89c4ff92SAndroid Build Coastguard Worker     pTfLiteParserImpl(new TfLiteParserImpl(options)) {}
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker ITfLiteParser::~ITfLiteParser() = default;
58*89c4ff92SAndroid Build Coastguard Worker 
CreateRaw(const armnn::Optional<TfLiteParserOptions> & options)59*89c4ff92SAndroid Build Coastguard Worker ITfLiteParser* ITfLiteParser::CreateRaw(const armnn::Optional<TfLiteParserOptions>& options)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker     return new ITfLiteParser(options);
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker 
Create(const armnn::Optional<TfLiteParserOptions> & options)64*89c4ff92SAndroid Build Coastguard Worker ITfLiteParserPtr ITfLiteParser::Create(const armnn::Optional<TfLiteParserOptions>& options)
65*89c4ff92SAndroid Build Coastguard Worker {
66*89c4ff92SAndroid Build Coastguard Worker     return ITfLiteParserPtr(CreateRaw(options), &ITfLiteParser::Destroy);
67*89c4ff92SAndroid Build Coastguard Worker }
68*89c4ff92SAndroid Build Coastguard Worker 
Destroy(ITfLiteParser * parser)69*89c4ff92SAndroid Build Coastguard Worker void ITfLiteParser::Destroy(ITfLiteParser* parser)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker     delete parser;
72*89c4ff92SAndroid Build Coastguard Worker }
73*89c4ff92SAndroid Build Coastguard Worker 
CreateNetworkFromBinaryFile(const char * graphFile)74*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr ITfLiteParser::CreateNetworkFromBinaryFile(const char* graphFile)
75*89c4ff92SAndroid Build Coastguard Worker {
76*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->CreateNetworkFromBinaryFile(graphFile);
77*89c4ff92SAndroid Build Coastguard Worker }
78*89c4ff92SAndroid Build Coastguard Worker 
CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent)79*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr ITfLiteParser::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
80*89c4ff92SAndroid Build Coastguard Worker {
81*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->CreateNetworkFromBinary(binaryContent);
82*89c4ff92SAndroid Build Coastguard Worker }
83*89c4ff92SAndroid Build Coastguard Worker 
GetNetworkInputBindingInfo(size_t subgraphId,const std::string & name) const84*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo ITfLiteParser::GetNetworkInputBindingInfo(size_t subgraphId,
85*89c4ff92SAndroid Build Coastguard Worker                                                            const std::string& name) const
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->GetNetworkInputBindingInfo(subgraphId, name);
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker 
GetNetworkOutputBindingInfo(size_t subgraphId,const std::string & name) const90*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo ITfLiteParser::GetNetworkOutputBindingInfo(size_t subgraphId,
91*89c4ff92SAndroid Build Coastguard Worker                                                             const std::string& name) const
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->GetNetworkOutputBindingInfo(subgraphId, name);
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphCount() const96*89c4ff92SAndroid Build Coastguard Worker size_t ITfLiteParser::GetSubgraphCount() const
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->GetSubgraphCount();
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphInputTensorNames(size_t subgraphId) const101*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> ITfLiteParser::GetSubgraphInputTensorNames(size_t subgraphId) const
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->GetSubgraphInputTensorNames(subgraphId);
104*89c4ff92SAndroid Build Coastguard Worker }
105*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphOutputTensorNames(size_t subgraphId) const106*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> ITfLiteParser::GetSubgraphOutputTensorNames(size_t subgraphId) const
107*89c4ff92SAndroid Build Coastguard Worker {
108*89c4ff92SAndroid Build Coastguard Worker     return pTfLiteParserImpl->GetSubgraphOutputTensorNames(subgraphId);
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker 
111*89c4ff92SAndroid Build Coastguard Worker namespace
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker 
// Sentinel operator index (the largest uint32_t) that CheckModel accepts even though it is
// not a real index into a subgraph's operator list.
constexpr uint32_t VIRTUAL_OPERATOR_ID = std::numeric_limits<uint32_t>::max();
115*89c4ff92SAndroid Build Coastguard Worker 
CheckSubgraph(const TfLiteParserImpl::ModelPtr & model,size_t subgraphIndex,const CheckLocation & location)116*89c4ff92SAndroid Build Coastguard Worker void CheckSubgraph(const TfLiteParserImpl::ModelPtr& model,
117*89c4ff92SAndroid Build Coastguard Worker                    size_t subgraphIndex,
118*89c4ff92SAndroid Build Coastguard Worker                    const CheckLocation& location)
119*89c4ff92SAndroid Build Coastguard Worker {
120*89c4ff92SAndroid Build Coastguard Worker     if (model.get() == nullptr)
121*89c4ff92SAndroid Build Coastguard Worker     {
122*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
123*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with invalid (null) model. "
124*89c4ff92SAndroid Build Coastguard Worker                         "Possible reason is that the model is not yet loaded and Unpack(ed). "
125*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} at {}",
126*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
127*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
128*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
129*89c4ff92SAndroid Build Coastguard Worker     }
130*89c4ff92SAndroid Build Coastguard Worker     else if (subgraphIndex >= model->subgraphs.size())
131*89c4ff92SAndroid Build Coastguard Worker     {
132*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
133*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with an invalid subgraph index. "
134*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} at {}",
135*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
136*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
137*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
138*89c4ff92SAndroid Build Coastguard Worker     }
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker #define CHECK_SUBGRAPH(MODEL, SUBGRAPH_INDEX) \
142*89c4ff92SAndroid Build Coastguard Worker     CheckSubgraph(MODEL, SUBGRAPH_INDEX, CHECK_LOCATION())
143*89c4ff92SAndroid Build Coastguard Worker 
CheckModel(const TfLiteParserImpl::ModelPtr & model,size_t subgraphIndex,size_t operatorIndex,const CheckLocation & location)144*89c4ff92SAndroid Build Coastguard Worker void CheckModel(const TfLiteParserImpl::ModelPtr& model,
145*89c4ff92SAndroid Build Coastguard Worker                 size_t subgraphIndex,
146*89c4ff92SAndroid Build Coastguard Worker                 size_t operatorIndex,
147*89c4ff92SAndroid Build Coastguard Worker                 const CheckLocation& location)
148*89c4ff92SAndroid Build Coastguard Worker {
149*89c4ff92SAndroid Build Coastguard Worker     if (model.get() == nullptr)
150*89c4ff92SAndroid Build Coastguard Worker     {
151*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
152*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with invalid (null) model. "
153*89c4ff92SAndroid Build Coastguard Worker                         "Possible reason is that the model is not yet loaded and Unpack(ed). "
154*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} operator:{} at {}",
155*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
156*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
157*89c4ff92SAndroid Build Coastguard Worker                         operatorIndex,
158*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
159*89c4ff92SAndroid Build Coastguard Worker     }
160*89c4ff92SAndroid Build Coastguard Worker     else if (subgraphIndex >= model->subgraphs.size())
161*89c4ff92SAndroid Build Coastguard Worker     {
162*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
163*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with an invalid subgraph index. "
164*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} operator:{} at {}",
165*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
166*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
167*89c4ff92SAndroid Build Coastguard Worker                         operatorIndex,
168*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
169*89c4ff92SAndroid Build Coastguard Worker     }
170*89c4ff92SAndroid Build Coastguard Worker     else if (operatorIndex >= model->subgraphs[subgraphIndex]->operators.size() &&
171*89c4ff92SAndroid Build Coastguard Worker              operatorIndex != VIRTUAL_OPERATOR_ID)
172*89c4ff92SAndroid Build Coastguard Worker     {
173*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
174*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with an invalid operator index. "
175*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} operator:{} at {}",
176*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
177*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
178*89c4ff92SAndroid Build Coastguard Worker                         operatorIndex,
179*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
180*89c4ff92SAndroid Build Coastguard Worker     }
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker #define CHECK_MODEL(MODEL, SUBGRAPH_INDEX, OPERATOR_INDEX) \
184*89c4ff92SAndroid Build Coastguard Worker     CheckModel(MODEL, SUBGRAPH_INDEX, OPERATOR_INDEX, CHECK_LOCATION())
185*89c4ff92SAndroid Build Coastguard Worker 
CheckTensor(const TfLiteParserImpl::ModelPtr & model,size_t subgraphIndex,size_t tensorIndex,const CheckLocation & location)186*89c4ff92SAndroid Build Coastguard Worker void CheckTensor(const TfLiteParserImpl::ModelPtr& model,
187*89c4ff92SAndroid Build Coastguard Worker                  size_t subgraphIndex,
188*89c4ff92SAndroid Build Coastguard Worker                  size_t tensorIndex,
189*89c4ff92SAndroid Build Coastguard Worker                  const CheckLocation& location)
190*89c4ff92SAndroid Build Coastguard Worker {
191*89c4ff92SAndroid Build Coastguard Worker     // not checking model, because I assume CHECK_MODEL already run
192*89c4ff92SAndroid Build Coastguard Worker     // and checked that. An assert would do.
193*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(model.get() != nullptr, "Expecting a valid model in this function");
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     // also subgraph index should be checked by CHECK_MODEL so
196*89c4ff92SAndroid Build Coastguard Worker     // I only add an assert here
197*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(subgraphIndex < model->subgraphs.size(), "Expecting a valid subgraph index");
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker     // the tensor index is the only one to check here
200*89c4ff92SAndroid Build Coastguard Worker     if (tensorIndex >= model->subgraphs[subgraphIndex]->tensors.size())
201*89c4ff92SAndroid Build Coastguard Worker     {
202*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
203*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with an invalid tensor index. "
204*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} tensor:{} at {}",
205*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
206*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
207*89c4ff92SAndroid Build Coastguard Worker                         tensorIndex,
208*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
209*89c4ff92SAndroid Build Coastguard Worker     }
210*89c4ff92SAndroid Build Coastguard Worker }
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker #define CHECK_TENSOR(MODEL, SUBGRAPH_INDEX, TENSOR_INDEX) \
213*89c4ff92SAndroid Build Coastguard Worker     CheckTensor(MODEL, SUBGRAPH_INDEX, TENSOR_INDEX, CHECK_LOCATION())
214*89c4ff92SAndroid Build Coastguard Worker 
CheckTensorPtr(TfLiteParserImpl::TensorRawPtr rawPtr,const CheckLocation & location)215*89c4ff92SAndroid Build Coastguard Worker void CheckTensorPtr(TfLiteParserImpl::TensorRawPtr rawPtr,
216*89c4ff92SAndroid Build Coastguard Worker                     const CheckLocation& location)
217*89c4ff92SAndroid Build Coastguard Worker {
218*89c4ff92SAndroid Build Coastguard Worker     if (rawPtr == nullptr)
219*89c4ff92SAndroid Build Coastguard Worker     {
220*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
221*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with a null tensor pointer at {}", location.m_Function, location.FileLine()));
222*89c4ff92SAndroid Build Coastguard Worker     }
223*89c4ff92SAndroid Build Coastguard Worker }
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker #define CHECK_TENSOR_PTR(TENSOR_PTR) \
226*89c4ff92SAndroid Build Coastguard Worker     CheckTensorPtr(TENSOR_PTR, CHECK_LOCATION())
227*89c4ff92SAndroid Build Coastguard Worker 
CheckBuffer(const TfLiteParserImpl::ModelPtr & model,size_t bufferIndex,const CheckLocation & location)228*89c4ff92SAndroid Build Coastguard Worker void CheckBuffer(const TfLiteParserImpl::ModelPtr& model,
229*89c4ff92SAndroid Build Coastguard Worker                  size_t bufferIndex,
230*89c4ff92SAndroid Build Coastguard Worker                  const CheckLocation& location)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker     if (model.get() == nullptr)
233*89c4ff92SAndroid Build Coastguard Worker     {
234*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
235*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with invalid (null) model. "
236*89c4ff92SAndroid Build Coastguard Worker                         "Possible reason is that the model is not yet loaded and Unpack(ed). "
237*89c4ff92SAndroid Build Coastguard Worker                         "buffer:{} at {}",
238*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
239*89c4ff92SAndroid Build Coastguard Worker                         bufferIndex,
240*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
241*89c4ff92SAndroid Build Coastguard Worker     }
242*89c4ff92SAndroid Build Coastguard Worker     else if (bufferIndex >= model->buffers.size())
243*89c4ff92SAndroid Build Coastguard Worker     {
244*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
245*89c4ff92SAndroid Build Coastguard Worker             fmt::format("{} was called with an invalid buffer index. "
246*89c4ff92SAndroid Build Coastguard Worker                         "buffer index:{} at {}",
247*89c4ff92SAndroid Build Coastguard Worker                         location.m_Function,
248*89c4ff92SAndroid Build Coastguard Worker                         bufferIndex,
249*89c4ff92SAndroid Build Coastguard Worker                         location.FileLine()));
250*89c4ff92SAndroid Build Coastguard Worker     }
251*89c4ff92SAndroid Build Coastguard Worker     else if (model->buffers[bufferIndex].get() == nullptr)
252*89c4ff92SAndroid Build Coastguard Worker     {
253*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
254*89c4ff92SAndroid Build Coastguard Worker             fmt::format("The buffer #{} is null. {}",
255*89c4ff92SAndroid Build Coastguard Worker                         bufferIndex,
256*89c4ff92SAndroid Build Coastguard Worker                         location.AsString()));
257*89c4ff92SAndroid Build Coastguard Worker     }
258*89c4ff92SAndroid Build Coastguard Worker }
259*89c4ff92SAndroid Build Coastguard Worker 
260*89c4ff92SAndroid Build Coastguard Worker #define CHECK_BUFFER(MODEL, BUFFER_INDEX) \
261*89c4ff92SAndroid Build Coastguard Worker     CheckBuffer(MODEL, BUFFER_INDEX, CHECK_LOCATION())
262*89c4ff92SAndroid Build Coastguard Worker 
CheckBufferSize(TfLiteParserImpl::BufferRawPtr bufferPtr,const armnn::TensorInfo & tensorInfo,uint32_t bufferId,const CheckLocation & location)263*89c4ff92SAndroid Build Coastguard Worker void CheckBufferSize(TfLiteParserImpl::BufferRawPtr bufferPtr,
264*89c4ff92SAndroid Build Coastguard Worker                      const armnn::TensorInfo& tensorInfo,
265*89c4ff92SAndroid Build Coastguard Worker                      uint32_t bufferId,
266*89c4ff92SAndroid Build Coastguard Worker                      const CheckLocation& location)
267*89c4ff92SAndroid Build Coastguard Worker {
268*89c4ff92SAndroid Build Coastguard Worker     if (bufferPtr == nullptr)
269*89c4ff92SAndroid Build Coastguard Worker     {
270*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
271*89c4ff92SAndroid Build Coastguard Worker             fmt::format("BufferPtr is null for buffer:{}. {}",
272*89c4ff92SAndroid Build Coastguard Worker                         bufferId,
273*89c4ff92SAndroid Build Coastguard Worker                         location.AsString()));
274*89c4ff92SAndroid Build Coastguard Worker     }
275*89c4ff92SAndroid Build Coastguard Worker     else if(tensorInfo.GetNumElements() > bufferPtr->data.size() ||
276*89c4ff92SAndroid Build Coastguard Worker             tensorInfo.GetNumBytes() > bufferPtr->data.size())
277*89c4ff92SAndroid Build Coastguard Worker     {
278*89c4ff92SAndroid Build Coastguard Worker         std::stringstream ss;
279*89c4ff92SAndroid Build Coastguard Worker         ss << "Buffer #" << bufferId << " has " << bufferPtr->data.size() << " bytes. "
280*89c4ff92SAndroid Build Coastguard Worker            << "For tensor: " << tensorInfo.GetShape()
281*89c4ff92SAndroid Build Coastguard Worker            << " expecting: " << tensorInfo.GetNumBytes() << " bytes and "
282*89c4ff92SAndroid Build Coastguard Worker            << tensorInfo.GetNumElements() << " elements. " << location.AsString();
283*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(ss.str());
284*89c4ff92SAndroid Build Coastguard Worker     }
285*89c4ff92SAndroid Build Coastguard Worker }
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker 
GetOpCode(const TfLiteParserImpl::ModelPtr & model,size_t subgraphIndex,size_t operatorIndex)288*89c4ff92SAndroid Build Coastguard Worker tflite::BuiltinOperator GetOpCode(const TfLiteParserImpl::ModelPtr& model, size_t subgraphIndex, size_t operatorIndex)
289*89c4ff92SAndroid Build Coastguard Worker {
290*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = model->subgraphs[subgraphIndex]->operators[operatorIndex];
291*89c4ff92SAndroid Build Coastguard Worker     auto opcodeIndex = operatorPtr->opcode_index;
292*89c4ff92SAndroid Build Coastguard Worker 
293*89c4ff92SAndroid Build Coastguard Worker // work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner
294*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_3)
295*89c4ff92SAndroid Build Coastguard Worker     auto opcode = std::max(model->operator_codes[opcodeIndex]->builtin_code,
296*89c4ff92SAndroid Build Coastguard Worker             static_cast<tflite::BuiltinOperator>(model->operator_codes[opcodeIndex]->deprecated_builtin_code));
297*89c4ff92SAndroid Build Coastguard Worker #else
298*89c4ff92SAndroid Build Coastguard Worker     auto opcode = model->operator_codes[opcodeIndex]->builtin_code;
299*89c4ff92SAndroid Build Coastguard Worker #endif
300*89c4ff92SAndroid Build Coastguard Worker     return opcode;
301*89c4ff92SAndroid Build Coastguard Worker }
302*89c4ff92SAndroid Build Coastguard Worker 
GetUIntBuffer(armnn::TensorInfo info,const TfLiteParserImpl::ModelPtr & model,size_t bufferIndex)303*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> GetUIntBuffer(armnn::TensorInfo info,
304*89c4ff92SAndroid Build Coastguard Worker                                         const TfLiteParserImpl::ModelPtr& model,
305*89c4ff92SAndroid Build Coastguard Worker                                         size_t bufferIndex)
306*89c4ff92SAndroid Build Coastguard Worker {
307*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::BufferRawPtr bufferPtr = TfLiteParserImpl::GetBuffer(model, bufferIndex);
308*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> buffer(info.GetNumElements());
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker     if (info.GetDataType() == DataType::Signed32)
311*89c4ff92SAndroid Build Coastguard Worker     {
312*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(buffer.data(), bufferPtr->data.data(), bufferPtr->data.size());
313*89c4ff92SAndroid Build Coastguard Worker     }
314*89c4ff92SAndroid Build Coastguard Worker     else if (info.GetDataType() == DataType::Signed64)
315*89c4ff92SAndroid Build Coastguard Worker     {
316*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint64_t> uint64Buffer(info.GetNumElements());
317*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(uint64Buffer.data(), bufferPtr->data.data(), bufferPtr->data.size());
318*89c4ff92SAndroid Build Coastguard Worker         buffer.assign(std::begin(uint64Buffer), std::end(uint64Buffer));
319*89c4ff92SAndroid Build Coastguard Worker     }
320*89c4ff92SAndroid Build Coastguard Worker     else
321*89c4ff92SAndroid Build Coastguard Worker     {
322*89c4ff92SAndroid Build Coastguard Worker         CheckLocation location = CHECK_LOCATION();
323*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
324*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Unsupported data type for uint buffer {}, only Signed 32 or Signed 64 are supported. {}",
325*89c4ff92SAndroid Build Coastguard Worker                             GetDataTypeName(info.GetDataType()),
326*89c4ff92SAndroid Build Coastguard Worker                             location.AsString()));
327*89c4ff92SAndroid Build Coastguard Worker     }
328*89c4ff92SAndroid Build Coastguard Worker     return buffer;
329*89c4ff92SAndroid Build Coastguard Worker }
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker #define CHECK_BUFFER_SIZE(BUFFER_PTR, TENSOR_INFO, BUFFER_ID) \
332*89c4ff92SAndroid Build Coastguard Worker     CheckBufferSize(BUFFER_PTR, TENSOR_INFO, BUFFER_ID, CHECK_LOCATION())
333*89c4ff92SAndroid Build Coastguard Worker 
IsActivationSupported(tflite::ActivationFunctionType activationType)334*89c4ff92SAndroid Build Coastguard Worker bool IsActivationSupported(tflite::ActivationFunctionType activationType)
335*89c4ff92SAndroid Build Coastguard Worker {
336*89c4ff92SAndroid Build Coastguard Worker     switch(activationType)
337*89c4ff92SAndroid Build Coastguard Worker     {
338*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_NONE:
339*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_RELU:
340*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_RELU6:
341*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_TANH:
342*89c4ff92SAndroid Build Coastguard Worker         {
343*89c4ff92SAndroid Build Coastguard Worker             return true;
344*89c4ff92SAndroid Build Coastguard Worker         }
345*89c4ff92SAndroid Build Coastguard Worker         default:
346*89c4ff92SAndroid Build Coastguard Worker         {
347*89c4ff92SAndroid Build Coastguard Worker             return false;
348*89c4ff92SAndroid Build Coastguard Worker         }
349*89c4ff92SAndroid Build Coastguard Worker     }
350*89c4ff92SAndroid Build Coastguard Worker }
351*89c4ff92SAndroid Build Coastguard Worker 
// Throws ParseException when OPTION->fused_activation_function is not accepted by
// IsActivationSupported. __func__ names the function in which the macro is expanded.
// Arguments are parenthesized, and the enum is cast to int explicitly: fmt 10+ no longer
// formats unscoped enums implicitly, and the cast prints the same number on older fmt.
#define CHECK_SUPPORTED_FUSED_ACTIVATION(OPTION, SUBGRAPH_INDEX, OPERATOR_INDEX) \
    do { \
        if (!IsActivationSupported((OPTION)->fused_activation_function)) \
        { \
            throw ParseException( \
                fmt::format("TfLite parser doesn't support fused activation: " \
                            "{}/{} in {} subgraph:{} operator:{} at {}", \
                            static_cast<int>((OPTION)->fused_activation_function), \
                            tflite::EnumNameActivationFunctionType( \
                            (OPTION)->fused_activation_function), \
                            __func__, \
                            (SUBGRAPH_INDEX), \
                            (OPERATOR_INDEX), \
                            CHECK_LOCATION().FileLine())); \
        } \
    } while(false)
368*89c4ff92SAndroid Build Coastguard Worker 
369*89c4ff92SAndroid Build Coastguard Worker 
AsUnsignedVector(const std::vector<int32_t> & in)370*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> AsUnsignedVector(const std::vector<int32_t>& in)
371*89c4ff92SAndroid Build Coastguard Worker {
372*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> result;
373*89c4ff92SAndroid Build Coastguard Worker     result.reserve(in.size());
374*89c4ff92SAndroid Build Coastguard Worker     for (auto& i : in)
375*89c4ff92SAndroid Build Coastguard Worker     {
376*89c4ff92SAndroid Build Coastguard Worker         // If the location of the input data is -1 then the input should be ignored.
377*89c4ff92SAndroid Build Coastguard Worker         if (i == -1)
378*89c4ff92SAndroid Build Coastguard Worker         {
379*89c4ff92SAndroid Build Coastguard Worker             continue;
380*89c4ff92SAndroid Build Coastguard Worker         }
381*89c4ff92SAndroid Build Coastguard Worker         result.push_back(CHECKED_NON_NEGATIVE(i));
382*89c4ff92SAndroid Build Coastguard Worker     }
383*89c4ff92SAndroid Build Coastguard Worker     return result;
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker 
/// Returns true when an optional operand index refers to a real operand.
/// Any negative index means the operand was left out.
bool IsOptionalOperandPresent(int input)
{
    const bool omitted = input < 0;
    return !omitted;
}
390*89c4ff92SAndroid Build Coastguard Worker 
CalcPadding(uint32_t inputSize,uint32_t filterSize,uint32_t stride,uint32_t dilation,uint32_t & paddingFront,uint32_t & paddingBack,tflite::Padding padding)391*89c4ff92SAndroid Build Coastguard Worker void CalcPadding(uint32_t inputSize,
392*89c4ff92SAndroid Build Coastguard Worker                  uint32_t filterSize,
393*89c4ff92SAndroid Build Coastguard Worker                  uint32_t stride,
394*89c4ff92SAndroid Build Coastguard Worker                  uint32_t dilation,
395*89c4ff92SAndroid Build Coastguard Worker                  uint32_t& paddingFront,
396*89c4ff92SAndroid Build Coastguard Worker                  uint32_t& paddingBack,
397*89c4ff92SAndroid Build Coastguard Worker                  tflite::Padding padding)
398*89c4ff92SAndroid Build Coastguard Worker {
399*89c4ff92SAndroid Build Coastguard Worker     paddingFront = 0;
400*89c4ff92SAndroid Build Coastguard Worker     paddingBack = 0;
401*89c4ff92SAndroid Build Coastguard Worker     if (padding == tflite::Padding_SAME)
402*89c4ff92SAndroid Build Coastguard Worker     {
403*89c4ff92SAndroid Build Coastguard Worker         uint32_t outputSize = (inputSize + stride - 1) / stride;
404*89c4ff92SAndroid Build Coastguard Worker         uint32_t dilatedSize = filterSize + (dilation - 1) * (filterSize - 1);
405*89c4ff92SAndroid Build Coastguard Worker         uint32_t temp = (outputSize - 1) * stride + dilatedSize;
406*89c4ff92SAndroid Build Coastguard Worker         if (temp > inputSize)
407*89c4ff92SAndroid Build Coastguard Worker         {
408*89c4ff92SAndroid Build Coastguard Worker             paddingFront = (temp - inputSize) / 2;
409*89c4ff92SAndroid Build Coastguard Worker             paddingBack = (temp - inputSize) - paddingFront;
410*89c4ff92SAndroid Build Coastguard Worker         }
411*89c4ff92SAndroid Build Coastguard Worker     }
412*89c4ff92SAndroid Build Coastguard Worker }
413*89c4ff92SAndroid Build Coastguard Worker 
ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,const std::vector<unsigned int> & shape,const bool outputTensor=false)414*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
415*89c4ff92SAndroid Build Coastguard Worker                                const std::vector<unsigned int>& shape,
416*89c4ff92SAndroid Build Coastguard Worker                                const bool outputTensor = false)
417*89c4ff92SAndroid Build Coastguard Worker {
418*89c4ff92SAndroid Build Coastguard Worker     armnn::DataType type;
419*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR_PTR(tensorPtr);
420*89c4ff92SAndroid Build Coastguard Worker 
421*89c4ff92SAndroid Build Coastguard Worker     switch (tensorPtr->type)
422*89c4ff92SAndroid Build Coastguard Worker     {
423*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_UINT8:
424*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::QAsymmU8;
425*89c4ff92SAndroid Build Coastguard Worker             break;
426*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_FLOAT32:
427*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::Float32;
428*89c4ff92SAndroid Build Coastguard Worker             break;
429*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_FLOAT16:
430*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::Float16;
431*89c4ff92SAndroid Build Coastguard Worker             break;
432*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_INT8:
433*89c4ff92SAndroid Build Coastguard Worker             if (tensorPtr->quantization->zero_point.size() == 1)
434*89c4ff92SAndroid Build Coastguard Worker             {
435*89c4ff92SAndroid Build Coastguard Worker                 // Per-tensor
436*89c4ff92SAndroid Build Coastguard Worker                 type = armnn::DataType::QAsymmS8;
437*89c4ff92SAndroid Build Coastguard Worker             }
438*89c4ff92SAndroid Build Coastguard Worker             else
439*89c4ff92SAndroid Build Coastguard Worker             {
440*89c4ff92SAndroid Build Coastguard Worker                 // Per-channel
441*89c4ff92SAndroid Build Coastguard Worker                 type = armnn::DataType::QSymmS8;
442*89c4ff92SAndroid Build Coastguard Worker             }
443*89c4ff92SAndroid Build Coastguard Worker             break;
444*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_INT16:
445*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::QSymmS16;
446*89c4ff92SAndroid Build Coastguard Worker             break;
447*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_INT32:
448*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::Signed32;
449*89c4ff92SAndroid Build Coastguard Worker             break;
450*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_INT64:
451*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::Signed64;
452*89c4ff92SAndroid Build Coastguard Worker             break;
453*89c4ff92SAndroid Build Coastguard Worker         case tflite::TensorType_BOOL:
454*89c4ff92SAndroid Build Coastguard Worker             type = armnn::DataType::Boolean;
455*89c4ff92SAndroid Build Coastguard Worker             break;
456*89c4ff92SAndroid Build Coastguard Worker         default:
457*89c4ff92SAndroid Build Coastguard Worker         {
458*89c4ff92SAndroid Build Coastguard Worker             CheckLocation location = CHECK_LOCATION();
459*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
460*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Unsupported data type {} = {} for tensor: {}. {}",
461*89c4ff92SAndroid Build Coastguard Worker                             tensorPtr->type,
462*89c4ff92SAndroid Build Coastguard Worker                             tflite::EnumNameTensorType(tensorPtr->type),
463*89c4ff92SAndroid Build Coastguard Worker                             tensorPtr->name,
464*89c4ff92SAndroid Build Coastguard Worker                             location.AsString()));
465*89c4ff92SAndroid Build Coastguard Worker         }
466*89c4ff92SAndroid Build Coastguard Worker     }
467*89c4ff92SAndroid Build Coastguard Worker     TensorShape tensorShape;
468*89c4ff92SAndroid Build Coastguard Worker 
469*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> safeShape = shape;
470*89c4ff92SAndroid Build Coastguard Worker     if (shape.size() == 0)
471*89c4ff92SAndroid Build Coastguard Worker     {
472*89c4ff92SAndroid Build Coastguard Worker         safeShape.push_back(1);
473*89c4ff92SAndroid Build Coastguard Worker     }
474*89c4ff92SAndroid Build Coastguard Worker 
475*89c4ff92SAndroid Build Coastguard Worker     if (!outputTensor)
476*89c4ff92SAndroid Build Coastguard Worker     {
477*89c4ff92SAndroid Build Coastguard Worker         tensorShape = TensorShape(armnn::numeric_cast<unsigned int>(safeShape.size()), safeShape.data());
478*89c4ff92SAndroid Build Coastguard Worker     }
479*89c4ff92SAndroid Build Coastguard Worker     else
480*89c4ff92SAndroid Build Coastguard Worker     {
481*89c4ff92SAndroid Build Coastguard Worker         size_t shapeSignatureSize = tensorPtr->shape_signature.size();
482*89c4ff92SAndroid Build Coastguard Worker 
483*89c4ff92SAndroid Build Coastguard Worker         // If a shape signature exists we will use that to infer dynamic tensors
484*89c4ff92SAndroid Build Coastguard Worker         if (shapeSignatureSize != 0)
485*89c4ff92SAndroid Build Coastguard Worker         {
486*89c4ff92SAndroid Build Coastguard Worker             // If the shape is incompatible with the shape signature override the shape
487*89c4ff92SAndroid Build Coastguard Worker             if (shapeSignatureSize != shape.size())
488*89c4ff92SAndroid Build Coastguard Worker             {
489*89c4ff92SAndroid Build Coastguard Worker                 safeShape = {};
490*89c4ff92SAndroid Build Coastguard Worker 
491*89c4ff92SAndroid Build Coastguard Worker                 for (unsigned int i = 0; i < shapeSignatureSize; ++i)
492*89c4ff92SAndroid Build Coastguard Worker                 {
493*89c4ff92SAndroid Build Coastguard Worker                     unsigned int dim = tensorPtr->shape_signature[i] > -1 ?
494*89c4ff92SAndroid Build Coastguard Worker                                        static_cast<unsigned int>(tensorPtr->shape_signature[i]) : 0;
495*89c4ff92SAndroid Build Coastguard Worker                     safeShape.push_back(dim);
496*89c4ff92SAndroid Build Coastguard Worker                 }
497*89c4ff92SAndroid Build Coastguard Worker             }
498*89c4ff92SAndroid Build Coastguard Worker 
499*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<bool[]> dimMask = std::make_unique<bool[]>(tensorPtr->shape_signature.size());
500*89c4ff92SAndroid Build Coastguard Worker             bool batchOnly = true;
501*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int i = 0; i < tensorPtr->shape_signature.size(); ++i)
502*89c4ff92SAndroid Build Coastguard Worker             {
503*89c4ff92SAndroid Build Coastguard Worker                 dimMask[i] = tensorPtr->shape_signature[i] != -1;
504*89c4ff92SAndroid Build Coastguard Worker 
505*89c4ff92SAndroid Build Coastguard Worker                 if (i > 0 && !dimMask[i])
506*89c4ff92SAndroid Build Coastguard Worker                 {
507*89c4ff92SAndroid Build Coastguard Worker                     batchOnly = false;
508*89c4ff92SAndroid Build Coastguard Worker                 }
509*89c4ff92SAndroid Build Coastguard Worker             }
510*89c4ff92SAndroid Build Coastguard Worker             if (batchOnly)
511*89c4ff92SAndroid Build Coastguard Worker             {
512*89c4ff92SAndroid Build Coastguard Worker                 dimMask[0] = true;
513*89c4ff92SAndroid Build Coastguard Worker             }
514*89c4ff92SAndroid Build Coastguard Worker             tensorShape = TensorShape(static_cast<unsigned int>(safeShape.size()), safeShape.data(), dimMask.get());
515*89c4ff92SAndroid Build Coastguard Worker         }
516*89c4ff92SAndroid Build Coastguard Worker         // If there is no shape signature treat the tensor as dynamic if the shape has a size of zero
517*89c4ff92SAndroid Build Coastguard Worker         else if (shape.size() == 0)
518*89c4ff92SAndroid Build Coastguard Worker         {
519*89c4ff92SAndroid Build Coastguard Worker             tensorShape = TensorShape(1, false);
520*89c4ff92SAndroid Build Coastguard Worker         }
521*89c4ff92SAndroid Build Coastguard Worker         else
522*89c4ff92SAndroid Build Coastguard Worker         {
523*89c4ff92SAndroid Build Coastguard Worker             tensorShape = TensorShape(armnn::numeric_cast<unsigned int>(shape.size()), shape.data());
524*89c4ff92SAndroid Build Coastguard Worker         }
525*89c4ff92SAndroid Build Coastguard Worker     }
526*89c4ff92SAndroid Build Coastguard Worker 
527*89c4ff92SAndroid Build Coastguard Worker     float quantizationScale = 1.0f;
528*89c4ff92SAndroid Build Coastguard Worker     int32_t quantizationOffset = 0;
529*89c4ff92SAndroid Build Coastguard Worker 
530*89c4ff92SAndroid Build Coastguard Worker     if (tensorPtr->quantization.get())
531*89c4ff92SAndroid Build Coastguard Worker     {
532*89c4ff92SAndroid Build Coastguard Worker         if (tensorPtr->quantization->scale.size() <= 1)
533*89c4ff92SAndroid Build Coastguard Worker         {
534*89c4ff92SAndroid Build Coastguard Worker             CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
535*89c4ff92SAndroid Build Coastguard Worker             CHECK_VALID_SIZE(tensorPtr->quantization->zero_point.size(), 0, 1);
536*89c4ff92SAndroid Build Coastguard Worker 
537*89c4ff92SAndroid Build Coastguard Worker             if (tensorPtr->quantization->scale.size() == 1)
538*89c4ff92SAndroid Build Coastguard Worker             {
539*89c4ff92SAndroid Build Coastguard Worker                 quantizationScale = tensorPtr->quantization->scale[0];
540*89c4ff92SAndroid Build Coastguard Worker             }
541*89c4ff92SAndroid Build Coastguard Worker             if (tensorPtr->quantization->zero_point.size() == 1)
542*89c4ff92SAndroid Build Coastguard Worker             {
543*89c4ff92SAndroid Build Coastguard Worker                 // NOTE: we lose precision here when converting from 64 bit to 32
544*89c4ff92SAndroid Build Coastguard Worker                 //       but this is what we support at the moment in ArmNN
545*89c4ff92SAndroid Build Coastguard Worker                 quantizationOffset = armnn::numeric_cast<int32_t>(tensorPtr->quantization->zero_point[0]);
546*89c4ff92SAndroid Build Coastguard Worker             }
547*89c4ff92SAndroid Build Coastguard Worker 
548*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorInfo result(tensorShape,
549*89c4ff92SAndroid Build Coastguard Worker                                      type,
550*89c4ff92SAndroid Build Coastguard Worker                                      quantizationScale,
551*89c4ff92SAndroid Build Coastguard Worker                                      quantizationOffset);
552*89c4ff92SAndroid Build Coastguard Worker             return result;
553*89c4ff92SAndroid Build Coastguard Worker         }
554*89c4ff92SAndroid Build Coastguard Worker         else
555*89c4ff92SAndroid Build Coastguard Worker         {
556*89c4ff92SAndroid Build Coastguard Worker             std::vector<float> quantizationScales;
557*89c4ff92SAndroid Build Coastguard Worker             std::vector<int32_t> quantizationOffsets;
558*89c4ff92SAndroid Build Coastguard Worker 
559*89c4ff92SAndroid Build Coastguard Worker             // Scale
560*89c4ff92SAndroid Build Coastguard Worker             std::copy(tensorPtr->quantization->scale.begin(),
561*89c4ff92SAndroid Build Coastguard Worker                       tensorPtr->quantization->scale.end(),
562*89c4ff92SAndroid Build Coastguard Worker                       std::back_inserter(quantizationScales));
563*89c4ff92SAndroid Build Coastguard Worker 
564*89c4ff92SAndroid Build Coastguard Worker             // QSymmS8 Per-axis
565*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorInfo result(tensorShape,
566*89c4ff92SAndroid Build Coastguard Worker                                      type,
567*89c4ff92SAndroid Build Coastguard Worker                                      quantizationScales,
568*89c4ff92SAndroid Build Coastguard Worker                                      armnn::numeric_cast<unsigned int>(tensorPtr->quantization->quantized_dimension));
569*89c4ff92SAndroid Build Coastguard Worker             return result;
570*89c4ff92SAndroid Build Coastguard Worker         }
571*89c4ff92SAndroid Build Coastguard Worker     }
572*89c4ff92SAndroid Build Coastguard Worker     else
573*89c4ff92SAndroid Build Coastguard Worker     {
574*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo result(tensorShape,
575*89c4ff92SAndroid Build Coastguard Worker                                  type,
576*89c4ff92SAndroid Build Coastguard Worker                                  quantizationScale,
577*89c4ff92SAndroid Build Coastguard Worker                                  quantizationOffset);
578*89c4ff92SAndroid Build Coastguard Worker         return result;
579*89c4ff92SAndroid Build Coastguard Worker     }
580*89c4ff92SAndroid Build Coastguard Worker }
581*89c4ff92SAndroid Build Coastguard Worker 
ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,const bool outputTensor=false)582*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo ToTensorInfo(TfLiteParserImpl::TensorRawPtr tensorPtr,
583*89c4ff92SAndroid Build Coastguard Worker                                const bool outputTensor = false)
584*89c4ff92SAndroid Build Coastguard Worker {
585*89c4ff92SAndroid Build Coastguard Worker     auto const& dimensions = AsUnsignedVector(tensorPtr->shape);
586*89c4ff92SAndroid Build Coastguard Worker     return ToTensorInfo(tensorPtr, dimensions, outputTensor);
587*89c4ff92SAndroid Build Coastguard Worker }
588*89c4ff92SAndroid Build Coastguard Worker 
589*89c4ff92SAndroid Build Coastguard Worker template<typename T>
590*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
CreateConstTensorImpl(TfLiteParserImpl::BufferRawPtr bufferPtr,TfLiteParserImpl::TensorRawPtr tensorPtr,armnn::TensorInfo & tensorInfo,armnn::Optional<armnn::PermutationVector &> permutationVector)591*89c4ff92SAndroid Build Coastguard Worker CreateConstTensorImpl(TfLiteParserImpl::BufferRawPtr bufferPtr,
592*89c4ff92SAndroid Build Coastguard Worker                       TfLiteParserImpl::TensorRawPtr tensorPtr,
593*89c4ff92SAndroid Build Coastguard Worker                       armnn::TensorInfo& tensorInfo,
594*89c4ff92SAndroid Build Coastguard Worker                       armnn::Optional<armnn::PermutationVector&> permutationVector)
595*89c4ff92SAndroid Build Coastguard Worker {
596*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(tensorPtr);
597*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(tensorPtr != nullptr, "tensorPtr is null");
598*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT_MSG(bufferPtr != nullptr,
599*89c4ff92SAndroid Build Coastguard Worker         fmt::format("Buffer for buffer:{} is null", tensorPtr->buffer).c_str());
600*89c4ff92SAndroid Build Coastguard Worker 
601*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);
602*89c4ff92SAndroid Build Coastguard Worker 
603*89c4ff92SAndroid Build Coastguard Worker     if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
604*89c4ff92SAndroid Build Coastguard Worker     {
605*89c4ff92SAndroid Build Coastguard Worker         tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
606*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
607*89c4ff92SAndroid Build Coastguard Worker                             reinterpret_cast<const T*>(bufferPtr->data.data()), data.get(), sizeof(T));
608*89c4ff92SAndroid Build Coastguard Worker     }
609*89c4ff92SAndroid Build Coastguard Worker     else
610*89c4ff92SAndroid Build Coastguard Worker     {
611*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(data.get(), bufferPtr->data.data(), tensorInfo.GetNumBytes());
612*89c4ff92SAndroid Build Coastguard Worker     }
613*89c4ff92SAndroid Build Coastguard Worker 
614*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
615*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
616*89c4ff92SAndroid Build Coastguard Worker 
617*89c4ff92SAndroid Build Coastguard Worker     return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
618*89c4ff92SAndroid Build Coastguard Worker }
619*89c4ff92SAndroid Build Coastguard Worker 
GenerateLayerBindingId(size_t subgraphIndex,size_t tensorIndex)620*89c4ff92SAndroid Build Coastguard Worker armnn::LayerBindingId GenerateLayerBindingId(size_t subgraphIndex, size_t tensorIndex)
621*89c4ff92SAndroid Build Coastguard Worker {
622*89c4ff92SAndroid Build Coastguard Worker     // generate the binding id by shifting the tensor id by 8 bit
623*89c4ff92SAndroid Build Coastguard Worker     // and add the subgraph id, which allows 256 subgraphs
624*89c4ff92SAndroid Build Coastguard Worker     return static_cast<armnn::LayerBindingId>((tensorIndex<<8)+subgraphIndex);
625*89c4ff92SAndroid Build Coastguard Worker }
626*89c4ff92SAndroid Build Coastguard Worker 
CheckShape(const armnn::TensorShape & actual,const std::vector<int32_t> & expected)627*89c4ff92SAndroid Build Coastguard Worker bool CheckShape(const armnn::TensorShape& actual, const std::vector<int32_t>& expected)
628*89c4ff92SAndroid Build Coastguard Worker {
629*89c4ff92SAndroid Build Coastguard Worker     const unsigned int actualSize = actual.GetNumDimensions();
630*89c4ff92SAndroid Build Coastguard Worker     if (actualSize != expected.size())
631*89c4ff92SAndroid Build Coastguard Worker     {
632*89c4ff92SAndroid Build Coastguard Worker         return false;
633*89c4ff92SAndroid Build Coastguard Worker     }
634*89c4ff92SAndroid Build Coastguard Worker 
635*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0u; i < actualSize; i++)
636*89c4ff92SAndroid Build Coastguard Worker     {
637*89c4ff92SAndroid Build Coastguard Worker         if (expected[i] < 0 ||
638*89c4ff92SAndroid Build Coastguard Worker             actual[i] != static_cast<unsigned int>(expected[i]))
639*89c4ff92SAndroid Build Coastguard Worker         {
640*89c4ff92SAndroid Build Coastguard Worker             return false;
641*89c4ff92SAndroid Build Coastguard Worker         }
642*89c4ff92SAndroid Build Coastguard Worker     }
643*89c4ff92SAndroid Build Coastguard Worker 
644*89c4ff92SAndroid Build Coastguard Worker     return true;
645*89c4ff92SAndroid Build Coastguard Worker }
646*89c4ff92SAndroid Build Coastguard Worker 
CheckShape(const armnn::TensorShape & actual,const armnn::TensorShape & expected)647*89c4ff92SAndroid Build Coastguard Worker bool CheckShape(const armnn::TensorShape& actual, const armnn::TensorShape& expected)
648*89c4ff92SAndroid Build Coastguard Worker {
649*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedVec;
650*89c4ff92SAndroid Build Coastguard Worker     for (uint32_t i = 0; i < expected.GetNumDimensions(); i++)
651*89c4ff92SAndroid Build Coastguard Worker     {
652*89c4ff92SAndroid Build Coastguard Worker         expectedVec.push_back(expected[i]);
653*89c4ff92SAndroid Build Coastguard Worker     }
654*89c4ff92SAndroid Build Coastguard Worker     return CheckShape(actual, expectedVec);
655*89c4ff92SAndroid Build Coastguard Worker }
656*89c4ff92SAndroid Build Coastguard Worker 
CheckMatchingQuantization(const TensorInfo & first,const TensorInfo & second,const std::string & descName,std::string const & firstName,std::string const & secondName)657*89c4ff92SAndroid Build Coastguard Worker void CheckMatchingQuantization(const TensorInfo& first,
658*89c4ff92SAndroid Build Coastguard Worker                                const TensorInfo& second,
659*89c4ff92SAndroid Build Coastguard Worker                                const std::string& descName,
660*89c4ff92SAndroid Build Coastguard Worker                                std::string const& firstName,
661*89c4ff92SAndroid Build Coastguard Worker                                std::string const& secondName)
662*89c4ff92SAndroid Build Coastguard Worker {
663*89c4ff92SAndroid Build Coastguard Worker     if (!first.IsQuantized() ||
664*89c4ff92SAndroid Build Coastguard Worker         !second.IsQuantized())
665*89c4ff92SAndroid Build Coastguard Worker     {
666*89c4ff92SAndroid Build Coastguard Worker         // Not a quantized type, ignore the validation
667*89c4ff92SAndroid Build Coastguard Worker         return;
668*89c4ff92SAndroid Build Coastguard Worker     }
669*89c4ff92SAndroid Build Coastguard Worker 
670*89c4ff92SAndroid Build Coastguard Worker     DataType firstDataType  = first.GetDataType();
671*89c4ff92SAndroid Build Coastguard Worker     DataType secondDataType = second.GetDataType();
672*89c4ff92SAndroid Build Coastguard Worker 
673*89c4ff92SAndroid Build Coastguard Worker     if (firstDataType != secondDataType)
674*89c4ff92SAndroid Build Coastguard Worker     {
675*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
676*89c4ff92SAndroid Build Coastguard Worker                                        " must be of the same quantized type, " +
677*89c4ff92SAndroid Build Coastguard Worker                                        firstName + " is " + GetDataTypeName(firstDataType) + ", " +
678*89c4ff92SAndroid Build Coastguard Worker                                        secondName + " is " + GetDataTypeName(secondDataType));
679*89c4ff92SAndroid Build Coastguard Worker     }
680*89c4ff92SAndroid Build Coastguard Worker 
681*89c4ff92SAndroid Build Coastguard Worker     if (!first.IsTypeSpaceMatch(second))
682*89c4ff92SAndroid Build Coastguard Worker     {
683*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(descName + ": " + firstName + " and " + secondName +
684*89c4ff92SAndroid Build Coastguard Worker                                        " must have the same quantization space, " +
685*89c4ff92SAndroid Build Coastguard Worker                                        firstName + " has offset " + std::to_string(first.GetQuantizationOffset()) +
686*89c4ff92SAndroid Build Coastguard Worker                                        " and scale " + std::to_string(first.GetQuantizationScale()) + ", " +
687*89c4ff92SAndroid Build Coastguard Worker                                        secondName + " has offset " + std::to_string(second.GetQuantizationOffset()) +
688*89c4ff92SAndroid Build Coastguard Worker                                        " and scale " + std::to_string(second.GetQuantizationScale()));
689*89c4ff92SAndroid Build Coastguard Worker     }
690*89c4ff92SAndroid Build Coastguard Worker }
691*89c4ff92SAndroid Build Coastguard Worker 
IsDynamic(TfLiteParserImpl::TensorRawPtr tensorPtr)692*89c4ff92SAndroid Build Coastguard Worker bool IsDynamic(TfLiteParserImpl::TensorRawPtr tensorPtr)
693*89c4ff92SAndroid Build Coastguard Worker {
694*89c4ff92SAndroid Build Coastguard Worker     auto shape = tensorPtr->shape;
695*89c4ff92SAndroid Build Coastguard Worker 
696*89c4ff92SAndroid Build Coastguard Worker     if (shape.empty())
697*89c4ff92SAndroid Build Coastguard Worker     {
698*89c4ff92SAndroid Build Coastguard Worker         return true;
699*89c4ff92SAndroid Build Coastguard Worker     }
700*89c4ff92SAndroid Build Coastguard Worker     auto shapeSig = tensorPtr->shape_signature;
701*89c4ff92SAndroid Build Coastguard Worker 
702*89c4ff92SAndroid Build Coastguard Worker     if (shapeSig.empty())
703*89c4ff92SAndroid Build Coastguard Worker     {
704*89c4ff92SAndroid Build Coastguard Worker         return false;
705*89c4ff92SAndroid Build Coastguard Worker     }
706*89c4ff92SAndroid Build Coastguard Worker 
707*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < shapeSig.size() ; ++i)
708*89c4ff92SAndroid Build Coastguard Worker     {
709*89c4ff92SAndroid Build Coastguard Worker         if (shapeSig[i] == -1)
710*89c4ff92SAndroid Build Coastguard Worker         {
711*89c4ff92SAndroid Build Coastguard Worker             return true;
712*89c4ff92SAndroid Build Coastguard Worker         }
713*89c4ff92SAndroid Build Coastguard Worker     }
714*89c4ff92SAndroid Build Coastguard Worker     return false;
715*89c4ff92SAndroid Build Coastguard Worker }
716*89c4ff92SAndroid Build Coastguard Worker 
717*89c4ff92SAndroid Build Coastguard Worker } // <anonymous>
718*89c4ff92SAndroid Build Coastguard Worker 
TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOptions> & options)719*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TfLiteParserImpl(const Optional<ITfLiteParser::TfLiteParserOptions>& options)
720*89c4ff92SAndroid Build Coastguard Worker : m_Options(options)
721*89c4ff92SAndroid Build Coastguard Worker , m_Network(nullptr, nullptr)
722*89c4ff92SAndroid Build Coastguard Worker , m_ParserFunctions(tflite::BuiltinOperator_MAX+1, &TfLiteParserImpl::ParseUnsupportedOperator)
723*89c4ff92SAndroid Build Coastguard Worker {
724*89c4ff92SAndroid Build Coastguard Worker     // register supported operators
725*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_ABS]                     = &TfLiteParserImpl::ParseAbs;
726*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_ADD]                     = &TfLiteParserImpl::ParseAdd;
727*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_ARG_MIN]                 = &TfLiteParserImpl::ParseArgMin;
728*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_ARG_MAX]                 = &TfLiteParserImpl::ParseArgMax;
729*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_AVERAGE_POOL_2D]         = &TfLiteParserImpl::ParseAveragePool2D;
730*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_BATCH_TO_SPACE_ND]       = &TfLiteParserImpl::ParseBatchToSpaceND;
731*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_BATCH_MATMUL]            = &TfLiteParserImpl::ParseBatchMatMul;
732*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_CEIL]                    = &TfLiteParserImpl::ParseCeil;
733*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_CAST]                    = &TfLiteParserImpl::ParseCast;
734*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_CONCATENATION]           = &TfLiteParserImpl::ParseConcatenation;
735*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_CONV_2D]                 = &TfLiteParserImpl::ParseConv2D;
736*89c4ff92SAndroid Build Coastguard Worker     // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed.
737*89c4ff92SAndroid Build Coastguard Worker     #if defined(ARMNN_POST_TFLITE_2_4)
738*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_CONV_3D]                 = &TfLiteParserImpl::ParseConv3D;
739*89c4ff92SAndroid Build Coastguard Worker     #endif
740*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_CUSTOM]                  = &TfLiteParserImpl::ParseCustomOperator;
741*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_DEPTH_TO_SPACE]          = &TfLiteParserImpl::ParseDepthToSpace;
742*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_DEPTHWISE_CONV_2D]       = &TfLiteParserImpl::ParseDepthwiseConv2D;
743*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_DEQUANTIZE]              = &TfLiteParserImpl::ParseDequantize;
744*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_DIV]                     = &TfLiteParserImpl::ParseDiv;
745*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_ELU]                     = &TfLiteParserImpl::ParseElu;
746*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_EQUAL]                   = &TfLiteParserImpl::ParseEqual;
747*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_EXP]                     = &TfLiteParserImpl::ParseExp;
748*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_EXPAND_DIMS]             = &TfLiteParserImpl::ParseExpandDims;
749*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_FLOOR_DIV]               = &TfLiteParserImpl::ParseFloorDiv;
750*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_FULLY_CONNECTED]         = &TfLiteParserImpl::ParseFullyConnected;
751*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_GATHER]                  = &TfLiteParserImpl::ParseGather;
752*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_GATHER_ND]               = &TfLiteParserImpl::ParseGatherNd;
753*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_GREATER]                 = &TfLiteParserImpl::ParseGreater;
754*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_GREATER_EQUAL]           = &TfLiteParserImpl::ParseGreaterOrEqual;
755*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_HARD_SWISH]              = &TfLiteParserImpl::ParseHardSwish;
756*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LEAKY_RELU]              = &TfLiteParserImpl::ParseLeakyRelu;
757*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LESS]                    = &TfLiteParserImpl::ParseLess;
758*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LESS_EQUAL]              = &TfLiteParserImpl::ParseLessOrEqual;
759*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION]
760*89c4ff92SAndroid Build Coastguard Worker             = &TfLiteParserImpl::ParseLocalResponseNormalization;
761*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LOG]                     = &TfLiteParserImpl::ParseLog;
762*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LOGICAL_NOT]             = &TfLiteParserImpl::ParseLogicalNot;
763*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LOGISTIC]                = &TfLiteParserImpl::ParseLogistic;
764*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_LOG_SOFTMAX]             = &TfLiteParserImpl::ParseLogSoftmax;
765*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_L2_NORMALIZATION]        = &TfLiteParserImpl::ParseL2Normalization;
766*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_MAX_POOL_2D]             = &TfLiteParserImpl::ParseMaxPool2D;
767*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_MAXIMUM]                 = &TfLiteParserImpl::ParseMaximum;
768*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_MEAN]                    = &TfLiteParserImpl::ParseMean;
769*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_MINIMUM]                 = &TfLiteParserImpl::ParseMinimum;
770*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_MIRROR_PAD]              = &TfLiteParserImpl::ParseMirrorPad;
771*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_MUL]                     = &TfLiteParserImpl::ParseMul;
772*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_NEG]                     = &TfLiteParserImpl::ParseNeg;
773*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_NOT_EQUAL]               = &TfLiteParserImpl::ParseNotEqual;
774*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_PACK]                    = &TfLiteParserImpl::ParsePack;
775*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_PAD]                     = &TfLiteParserImpl::ParsePad;
776*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_PADV2]                   = &TfLiteParserImpl::ParsePad;
777*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_PRELU]                   = &TfLiteParserImpl::ParsePrelu;
778*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_QUANTIZE]                = &TfLiteParserImpl::ParseQuantize;
779*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_RELU]                    = &TfLiteParserImpl::ParseRelu;
780*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_RELU6]                   = &TfLiteParserImpl::ParseRelu6;
781*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_REDUCE_MAX]              = &TfLiteParserImpl::ParseReduceMax;
782*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_REDUCE_MIN]              = &TfLiteParserImpl::ParseReduceMin;
783*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_REDUCE_PROD]             = &TfLiteParserImpl::ParseReduceProd;
784*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_RESHAPE]                 = &TfLiteParserImpl::ParseReshape;
785*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_RESIZE_BILINEAR]         = &TfLiteParserImpl::ParseResizeBilinear;
786*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_RESIZE_NEAREST_NEIGHBOR] = &TfLiteParserImpl::ParseResizeNearestNeighbor;
787*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_RSQRT]                   = &TfLiteParserImpl::ParseRsqrt;
788*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SQRT]                    = &TfLiteParserImpl::ParseSqrt;
789*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SHAPE]                   = &TfLiteParserImpl::ParseShape;
790*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SIN]                     = &TfLiteParserImpl::ParseSin;
791*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SLICE]                   = &TfLiteParserImpl::ParseSlice;
792*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SOFTMAX]                 = &TfLiteParserImpl::ParseSoftmax;
793*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SPACE_TO_BATCH_ND]       = &TfLiteParserImpl::ParseSpaceToBatchND;
794*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SPACE_TO_DEPTH]          = &TfLiteParserImpl::ParseSpaceToDepth;
795*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SPLIT]                   = &TfLiteParserImpl::ParseSplit;
796*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SPLIT_V]                 = &TfLiteParserImpl::ParseSplitV;
797*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SQUEEZE]                 = &TfLiteParserImpl::ParseSqueeze;
798*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_STRIDED_SLICE]           = &TfLiteParserImpl::ParseStridedSlice;
799*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SUB]                     = &TfLiteParserImpl::ParseSub;
800*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_SUM]                     = &TfLiteParserImpl::ParseSum;
801*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_TANH]                    = &TfLiteParserImpl::ParseTanH;
802*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE]               = &TfLiteParserImpl::ParseTranspose;
803*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_TRANSPOSE_CONV]          = &TfLiteParserImpl::ParseTransposeConv;
804*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM]
805*89c4ff92SAndroid Build Coastguard Worker             = &TfLiteParserImpl::ParseUnidirectionalSequenceLSTM;
806*89c4ff92SAndroid Build Coastguard Worker     m_ParserFunctions[tflite::BuiltinOperator_UNPACK]                  = &TfLiteParserImpl::ParseUnpack;
807*89c4ff92SAndroid Build Coastguard Worker 
808*89c4ff92SAndroid Build Coastguard Worker     // register supported custom operators
809*89c4ff92SAndroid Build Coastguard Worker     m_CustomParserFunctions["TFLite_Detection_PostProcess"]      = &TfLiteParserImpl::ParseDetectionPostProcess;
810*89c4ff92SAndroid Build Coastguard Worker }
811*89c4ff92SAndroid Build Coastguard Worker 
InputTensorInfo(size_t subgraphIndex,size_t operatorIndex,int input)812*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo TfLiteParserImpl::InputTensorInfo(size_t subgraphIndex,
813*89c4ff92SAndroid Build Coastguard Worker                                     size_t operatorIndex,
814*89c4ff92SAndroid Build Coastguard Worker                                     int input)
815*89c4ff92SAndroid Build Coastguard Worker {
816*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = m_Model->subgraphs[subgraphIndex];
817*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
818*89c4ff92SAndroid Build Coastguard Worker 
819*89c4ff92SAndroid Build Coastguard Worker     uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[input]);
820*89c4ff92SAndroid Build Coastguard Worker     auto search = armnnTfLiteParser::TfLiteParserImpl::m_TensorInfos.find(inputId);
821*89c4ff92SAndroid Build Coastguard Worker 
822*89c4ff92SAndroid Build Coastguard Worker     if (search != m_TensorInfos.end())
823*89c4ff92SAndroid Build Coastguard Worker     {
824*89c4ff92SAndroid Build Coastguard Worker         return m_TensorInfos[inputId];
825*89c4ff92SAndroid Build Coastguard Worker     }
826*89c4ff92SAndroid Build Coastguard Worker     else
827*89c4ff92SAndroid Build Coastguard Worker     {
828*89c4ff92SAndroid Build Coastguard Worker         auto tensorInfo = ::armnnTfLiteParser::ToTensorInfo(subgraphPtr->tensors[inputId].get());
829*89c4ff92SAndroid Build Coastguard Worker         m_TensorInfos.insert({ inputId, tensorInfo });
830*89c4ff92SAndroid Build Coastguard Worker         return tensorInfo;
831*89c4ff92SAndroid Build Coastguard Worker     }
832*89c4ff92SAndroid Build Coastguard Worker }
833*89c4ff92SAndroid Build Coastguard Worker 
OutputTensorInfoFromInputs(size_t subgraphIndex,size_t operatorIndex,armnn::IConnectableLayer * layer,int output,std::vector<int> inputs)834*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo TfLiteParserImpl::OutputTensorInfoFromInputs(size_t subgraphIndex,
835*89c4ff92SAndroid Build Coastguard Worker                                                                size_t operatorIndex,
836*89c4ff92SAndroid Build Coastguard Worker                                                                armnn::IConnectableLayer* layer,
837*89c4ff92SAndroid Build Coastguard Worker                                                                int output,
838*89c4ff92SAndroid Build Coastguard Worker                                                                std::vector<int> inputs)
839*89c4ff92SAndroid Build Coastguard Worker {
840*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = m_Model->subgraphs[subgraphIndex];
841*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
842*89c4ff92SAndroid Build Coastguard Worker 
843*89c4ff92SAndroid Build Coastguard Worker     uint32_t outputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[output]);
844*89c4ff92SAndroid Build Coastguard Worker 
845*89c4ff92SAndroid Build Coastguard Worker     auto outputSearch = armnnTfLiteParser::TfLiteParserImpl::m_TensorInfos.find(outputId);
846*89c4ff92SAndroid Build Coastguard Worker 
847*89c4ff92SAndroid Build Coastguard Worker     if (outputSearch != m_TensorInfos.end())
848*89c4ff92SAndroid Build Coastguard Worker     {
849*89c4ff92SAndroid Build Coastguard Worker         return m_TensorInfos[outputId];
850*89c4ff92SAndroid Build Coastguard Worker     }
851*89c4ff92SAndroid Build Coastguard Worker 
852*89c4ff92SAndroid Build Coastguard Worker     const auto& outputTensorPtr = subgraphPtr->tensors[outputId].get();
853*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensor = ::armnnTfLiteParser::ToTensorInfo(outputTensorPtr, true);
854*89c4ff92SAndroid Build Coastguard Worker 
855*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamic(outputTensorPtr))
856*89c4ff92SAndroid Build Coastguard Worker     {
857*89c4ff92SAndroid Build Coastguard Worker         if (inputs.empty())
858*89c4ff92SAndroid Build Coastguard Worker         {
859*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int i = 0; i < layer->GetNumInputSlots(); ++i)
860*89c4ff92SAndroid Build Coastguard Worker             {
861*89c4ff92SAndroid Build Coastguard Worker                 inputs.emplace_back(i);
862*89c4ff92SAndroid Build Coastguard Worker             }
863*89c4ff92SAndroid Build Coastguard Worker         }
864*89c4ff92SAndroid Build Coastguard Worker         auto inputTensorIds = GetInputTensorIds(m_Model, subgraphIndex, operatorIndex);
865*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::TensorShape> inputShapes;
866*89c4ff92SAndroid Build Coastguard Worker 
867*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < inputs.size(); ++i)
868*89c4ff92SAndroid Build Coastguard Worker         {
869*89c4ff92SAndroid Build Coastguard Worker             uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[inputs[i]]);
870*89c4ff92SAndroid Build Coastguard Worker             auto search = armnnTfLiteParser::TfLiteParserImpl::m_TensorInfos.find(inputId);
871*89c4ff92SAndroid Build Coastguard Worker 
872*89c4ff92SAndroid Build Coastguard Worker             if (search != m_TensorInfos.end())
873*89c4ff92SAndroid Build Coastguard Worker             {
874*89c4ff92SAndroid Build Coastguard Worker                 auto &inputTensorInfo = m_TensorInfos[inputId];
875*89c4ff92SAndroid Build Coastguard Worker                 inputShapes.push_back(inputTensorInfo.GetShape());
876*89c4ff92SAndroid Build Coastguard Worker             }
877*89c4ff92SAndroid Build Coastguard Worker             else
878*89c4ff92SAndroid Build Coastguard Worker             {
879*89c4ff92SAndroid Build Coastguard Worker                 m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
880*89c4ff92SAndroid Build Coastguard Worker                 auto inputTensorInfo = ::armnnTfLiteParser::ToTensorInfo(subgraphPtr->tensors[inputId].get());
881*89c4ff92SAndroid Build Coastguard Worker                 m_TensorInfos.insert({ inputId, inputTensorInfo});
882*89c4ff92SAndroid Build Coastguard Worker                 inputShapes.push_back(inputTensorInfo.GetShape());
883*89c4ff92SAndroid Build Coastguard Worker             }
884*89c4ff92SAndroid Build Coastguard Worker         }
885*89c4ff92SAndroid Build Coastguard Worker         const auto outputShape = layer->InferOutputShapes(inputShapes)[output];
886*89c4ff92SAndroid Build Coastguard Worker         tensor.SetShape(outputShape);
887*89c4ff92SAndroid Build Coastguard Worker     }
888*89c4ff92SAndroid Build Coastguard Worker     m_TensorInfos.insert({ outputId, tensor});
889*89c4ff92SAndroid Build Coastguard Worker     return tensor;
890*89c4ff92SAndroid Build Coastguard Worker }
891*89c4ff92SAndroid Build Coastguard Worker 
OutputTensorInfoFromShapes(size_t subgraphIndex,size_t operatorIndex,armnn::IConnectableLayer * layer,int output,std::vector<armnn::TensorShape> inputShapes)892*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo TfLiteParserImpl::OutputTensorInfoFromShapes(size_t subgraphIndex,
893*89c4ff92SAndroid Build Coastguard Worker                                                                size_t operatorIndex,
894*89c4ff92SAndroid Build Coastguard Worker                                                                armnn::IConnectableLayer* layer,
895*89c4ff92SAndroid Build Coastguard Worker                                                                int output,
896*89c4ff92SAndroid Build Coastguard Worker                                                                std::vector<armnn::TensorShape> inputShapes)
897*89c4ff92SAndroid Build Coastguard Worker {
898*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = m_Model->subgraphs[subgraphIndex];
899*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
900*89c4ff92SAndroid Build Coastguard Worker 
901*89c4ff92SAndroid Build Coastguard Worker     uint32_t outputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[output]);
902*89c4ff92SAndroid Build Coastguard Worker     const auto& outputTensorPtr = subgraphPtr->tensors[outputId].get();
903*89c4ff92SAndroid Build Coastguard Worker     TensorInfo tensor = ::armnnTfLiteParser::ToTensorInfo(outputTensorPtr, true);
904*89c4ff92SAndroid Build Coastguard Worker 
905*89c4ff92SAndroid Build Coastguard Worker     if (IsDynamic(outputTensorPtr))
906*89c4ff92SAndroid Build Coastguard Worker     {
907*89c4ff92SAndroid Build Coastguard Worker         const auto outputShape = layer->InferOutputShapes(inputShapes)[output];
908*89c4ff92SAndroid Build Coastguard Worker         tensor.SetShape(outputShape);
909*89c4ff92SAndroid Build Coastguard Worker     }
910*89c4ff92SAndroid Build Coastguard Worker     m_TensorInfos.insert({ outputId, tensor});
911*89c4ff92SAndroid Build Coastguard Worker     return tensor;
912*89c4ff92SAndroid Build Coastguard Worker }
913*89c4ff92SAndroid Build Coastguard Worker 
ResetParser()914*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ResetParser()
915*89c4ff92SAndroid Build Coastguard Worker {
916*89c4ff92SAndroid Build Coastguard Worker     m_Network = armnn::INetworkPtr(nullptr, nullptr);
917*89c4ff92SAndroid Build Coastguard Worker     m_Model = nullptr;
918*89c4ff92SAndroid Build Coastguard Worker     m_SubgraphConnections.clear();
919*89c4ff92SAndroid Build Coastguard Worker     m_OverriddenOutputShapes.clear();
920*89c4ff92SAndroid Build Coastguard Worker     m_ConstantsToDequantize.clear();
921*89c4ff92SAndroid Build Coastguard Worker     m_ConstantsToBeCreated.clear();
922*89c4ff92SAndroid Build Coastguard Worker     m_TensorInfos.clear();
923*89c4ff92SAndroid Build Coastguard Worker }
924*89c4ff92SAndroid Build Coastguard Worker 
CreateNetworkFromBinaryFile(const char * graphFile)925*89c4ff92SAndroid Build Coastguard Worker INetworkPtr TfLiteParserImpl::CreateNetworkFromBinaryFile(const char* graphFile)
926*89c4ff92SAndroid Build Coastguard Worker {
927*89c4ff92SAndroid Build Coastguard Worker     ResetParser();
928*89c4ff92SAndroid Build Coastguard Worker     m_Model = LoadModelFromFile(graphFile);
929*89c4ff92SAndroid Build Coastguard Worker     return CreateNetworkFromModel();
930*89c4ff92SAndroid Build Coastguard Worker }
931*89c4ff92SAndroid Build Coastguard Worker 
CreateNetworkFromBinary(const std::vector<uint8_t> & binaryContent)932*89c4ff92SAndroid Build Coastguard Worker INetworkPtr TfLiteParserImpl::CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent)
933*89c4ff92SAndroid Build Coastguard Worker {
934*89c4ff92SAndroid Build Coastguard Worker     ResetParser();
935*89c4ff92SAndroid Build Coastguard Worker     m_Model = LoadModelFromBinary(binaryContent.data(), binaryContent.size());
936*89c4ff92SAndroid Build Coastguard Worker     return CreateNetworkFromModel();
937*89c4ff92SAndroid Build Coastguard Worker }
938*89c4ff92SAndroid Build Coastguard Worker 
939*89c4ff92SAndroid Build Coastguard Worker 
LoadModel(std::unique_ptr<tflite::ModelT> model)940*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr TfLiteParserImpl::LoadModel(std::unique_ptr<tflite::ModelT> model)
941*89c4ff92SAndroid Build Coastguard Worker {
942*89c4ff92SAndroid Build Coastguard Worker     ResetParser();
943*89c4ff92SAndroid Build Coastguard Worker     m_Model = std::move(model);
944*89c4ff92SAndroid Build Coastguard Worker 
945*89c4ff92SAndroid Build Coastguard Worker     return CreateNetworkFromModel();
946*89c4ff92SAndroid Build Coastguard Worker }
947*89c4ff92SAndroid Build Coastguard Worker 
CreateNetworkFromModel()948*89c4ff92SAndroid Build Coastguard Worker INetworkPtr TfLiteParserImpl::CreateNetworkFromModel()
949*89c4ff92SAndroid Build Coastguard Worker {
950*89c4ff92SAndroid Build Coastguard Worker 
951*89c4ff92SAndroid Build Coastguard Worker     using NetworkOptions = std::vector<BackendOptions>;
952*89c4ff92SAndroid Build Coastguard Worker     NetworkOptions networkOptions = {};
953*89c4ff92SAndroid Build Coastguard Worker     if (m_Options)
954*89c4ff92SAndroid Build Coastguard Worker     {
955*89c4ff92SAndroid Build Coastguard Worker         if (m_Options.value().m_InferAndValidate)
956*89c4ff92SAndroid Build Coastguard Worker         {
957*89c4ff92SAndroid Build Coastguard Worker             BackendOptions shapeInferenceMethodOption("ShapeInferenceMethod",
958*89c4ff92SAndroid Build Coastguard Worker                                                       {
959*89c4ff92SAndroid Build Coastguard Worker                                                           { "InferAndValidate", true }
960*89c4ff92SAndroid Build Coastguard Worker                                                       });
961*89c4ff92SAndroid Build Coastguard Worker 
962*89c4ff92SAndroid Build Coastguard Worker             networkOptions.push_back(shapeInferenceMethodOption);
963*89c4ff92SAndroid Build Coastguard Worker         }
964*89c4ff92SAndroid Build Coastguard Worker         if (m_Options.value().m_AllowExpandedDims)
965*89c4ff92SAndroid Build Coastguard Worker         {
966*89c4ff92SAndroid Build Coastguard Worker             BackendOptions shapeInferenceMethodOption("AllowExpandedDims",
967*89c4ff92SAndroid Build Coastguard Worker                                                       {
968*89c4ff92SAndroid Build Coastguard Worker                                                           { "AllowExpandedDims", true }
969*89c4ff92SAndroid Build Coastguard Worker                                                       });
970*89c4ff92SAndroid Build Coastguard Worker 
971*89c4ff92SAndroid Build Coastguard Worker             networkOptions.push_back(shapeInferenceMethodOption);
972*89c4ff92SAndroid Build Coastguard Worker         }
973*89c4ff92SAndroid Build Coastguard Worker     }
974*89c4ff92SAndroid Build Coastguard Worker     m_Network = INetwork::Create(networkOptions);
975*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(m_Model.get() != nullptr);
976*89c4ff92SAndroid Build Coastguard Worker 
977*89c4ff92SAndroid Build Coastguard Worker     if (m_Model->subgraphs.size() != 1)
978*89c4ff92SAndroid Build Coastguard Worker     {
979*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
980*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Current TfLite parser only supports 1 subgraph. Current one has: {} {}",
981*89c4ff92SAndroid Build Coastguard Worker                             m_Model->subgraphs.size(),
982*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
983*89c4ff92SAndroid Build Coastguard Worker     }
984*89c4ff92SAndroid Build Coastguard Worker 
985*89c4ff92SAndroid Build Coastguard Worker     size_t subgraphIndex = 0;
986*89c4ff92SAndroid Build Coastguard Worker     size_t operatorIndex = 0;
987*89c4ff92SAndroid Build Coastguard Worker     try
988*89c4ff92SAndroid Build Coastguard Worker     {
989*89c4ff92SAndroid Build Coastguard Worker         for (SubgraphPtr const& subgraph : m_Model->subgraphs)
990*89c4ff92SAndroid Build Coastguard Worker         {
991*89c4ff92SAndroid Build Coastguard Worker             SetupInputLayerTensorInfos(subgraphIndex);
992*89c4ff92SAndroid Build Coastguard Worker             SetupConstantLayerTensorInfos(subgraphIndex);
993*89c4ff92SAndroid Build Coastguard Worker 
994*89c4ff92SAndroid Build Coastguard Worker             m_SubgraphConnections.emplace_back(subgraph->tensors.size());
995*89c4ff92SAndroid Build Coastguard Worker             for (OperatorPtr const& op : subgraph->operators)
996*89c4ff92SAndroid Build Coastguard Worker             {
997*89c4ff92SAndroid Build Coastguard Worker                 auto const& opCodePtr = m_Model->operator_codes[op->opcode_index];
998*89c4ff92SAndroid Build Coastguard Worker 
999*89c4ff92SAndroid Build Coastguard Worker // work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner
1000*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_3)
1001*89c4ff92SAndroid Build Coastguard Worker                 auto builtinCode = std::max(opCodePtr->builtin_code,
1002*89c4ff92SAndroid Build Coastguard Worker                         static_cast<tflite::BuiltinOperator>(opCodePtr->deprecated_builtin_code));
1003*89c4ff92SAndroid Build Coastguard Worker #else
1004*89c4ff92SAndroid Build Coastguard Worker                 auto builtinCode = opCodePtr->builtin_code;
1005*89c4ff92SAndroid Build Coastguard Worker #endif
1006*89c4ff92SAndroid Build Coastguard Worker 
1007*89c4ff92SAndroid Build Coastguard Worker                 if (builtinCode > tflite::BuiltinOperator_MAX)
1008*89c4ff92SAndroid Build Coastguard Worker                 {
1009*89c4ff92SAndroid Build Coastguard Worker                     throw ParseException(fmt::format("Operator code {} is out of range 0-{}. "
1010*89c4ff92SAndroid Build Coastguard Worker                                                      "subgraph:{} operator idx:{}. {}",
1011*89c4ff92SAndroid Build Coastguard Worker                                                      builtinCode, tflite::BuiltinOperator_MAX, subgraphIndex,
1012*89c4ff92SAndroid Build Coastguard Worker                                                      operatorIndex, CHECK_LOCATION().AsString()));
1013*89c4ff92SAndroid Build Coastguard Worker                 }
1014*89c4ff92SAndroid Build Coastguard Worker 
1015*89c4ff92SAndroid Build Coastguard Worker                 // lookup and call the parser function
1016*89c4ff92SAndroid Build Coastguard Worker                 auto& parserFunction = m_ParserFunctions[builtinCode];
1017*89c4ff92SAndroid Build Coastguard Worker                 (this->*parserFunction)(subgraphIndex, operatorIndex);
1018*89c4ff92SAndroid Build Coastguard Worker                 ++operatorIndex;
1019*89c4ff92SAndroid Build Coastguard Worker             }
1020*89c4ff92SAndroid Build Coastguard Worker 
1021*89c4ff92SAndroid Build Coastguard Worker             SetupInputLayers(subgraphIndex);
1022*89c4ff92SAndroid Build Coastguard Worker             SetupOutputLayers(subgraphIndex);
1023*89c4ff92SAndroid Build Coastguard Worker             SetupConstantLayers(subgraphIndex);
1024*89c4ff92SAndroid Build Coastguard Worker 
1025*89c4ff92SAndroid Build Coastguard Worker             ++subgraphIndex;
1026*89c4ff92SAndroid Build Coastguard Worker             operatorIndex = 0;
1027*89c4ff92SAndroid Build Coastguard Worker         }
1028*89c4ff92SAndroid Build Coastguard Worker     }
1029*89c4ff92SAndroid Build Coastguard Worker     catch (const ParseException& e)
1030*89c4ff92SAndroid Build Coastguard Worker     {
1031*89c4ff92SAndroid Build Coastguard Worker         std::stringstream errorString;
1032*89c4ff92SAndroid Build Coastguard Worker         errorString << "Failed to parse operator #" << operatorIndex << " within subgraph #"
1033*89c4ff92SAndroid Build Coastguard Worker                     << subgraphIndex << " error: " << e.what();
1034*89c4ff92SAndroid Build Coastguard Worker         ARMNN_LOG(error) << errorString.str();
1035*89c4ff92SAndroid Build Coastguard Worker         std::stringstream errors;
1036*89c4ff92SAndroid Build Coastguard Worker         errors << errorString.str() << "\n";
1037*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(errors.str());
1038*89c4ff92SAndroid Build Coastguard Worker     }
1039*89c4ff92SAndroid Build Coastguard Worker 
1040*89c4ff92SAndroid Build Coastguard Worker     // establish the connections from the layer outputs to the inputs of the subsequent layers
1041*89c4ff92SAndroid Build Coastguard Worker     for (subgraphIndex = 0; subgraphIndex < m_SubgraphConnections.size(); ++subgraphIndex)
1042*89c4ff92SAndroid Build Coastguard Worker     {
1043*89c4ff92SAndroid Build Coastguard Worker         for (size_t tensorIndex = 0; tensorIndex < m_SubgraphConnections[subgraphIndex].size(); ++tensorIndex)
1044*89c4ff92SAndroid Build Coastguard Worker         {
1045*89c4ff92SAndroid Build Coastguard Worker             if (m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot != nullptr)
1046*89c4ff92SAndroid Build Coastguard Worker             {
1047*89c4ff92SAndroid Build Coastguard Worker                 for (size_t inputSlotIdx = 0;
1048*89c4ff92SAndroid Build Coastguard Worker                     inputSlotIdx < m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots.size();
1049*89c4ff92SAndroid Build Coastguard Worker                     ++inputSlotIdx)
1050*89c4ff92SAndroid Build Coastguard Worker                 {
1051*89c4ff92SAndroid Build Coastguard Worker                     m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot->Connect(
1052*89c4ff92SAndroid Build Coastguard Worker                         *(m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots[inputSlotIdx]));
1053*89c4ff92SAndroid Build Coastguard Worker                 }
1054*89c4ff92SAndroid Build Coastguard Worker             }
1055*89c4ff92SAndroid Build Coastguard Worker         }
1056*89c4ff92SAndroid Build Coastguard Worker     }
1057*89c4ff92SAndroid Build Coastguard Worker     return std::move(m_Network);
1058*89c4ff92SAndroid Build Coastguard Worker }
1059*89c4ff92SAndroid Build Coastguard Worker 
ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,armnn::DataType inputDataType,armnn::DataType tensorDataType)1060*89c4ff92SAndroid Build Coastguard Worker bool TfLiteParserImpl::ShouldConstantTensorBeConverted(TfLiteParserImpl::TensorRawPtr tensorPtr,
1061*89c4ff92SAndroid Build Coastguard Worker                                                        armnn::DataType inputDataType,
1062*89c4ff92SAndroid Build Coastguard Worker                                                        armnn::DataType tensorDataType)
1063*89c4ff92SAndroid Build Coastguard Worker {
1064*89c4ff92SAndroid Build Coastguard Worker     return (TfLiteParserImpl::IsConstTensor(tensorPtr) && inputDataType == DataType::Float32 &&
1065*89c4ff92SAndroid Build Coastguard Worker             (tensorDataType == DataType::QAsymmU8 ||
1066*89c4ff92SAndroid Build Coastguard Worker              tensorDataType == DataType::QAsymmS8 ||
1067*89c4ff92SAndroid Build Coastguard Worker              tensorDataType == DataType::QSymmS8 ||
1068*89c4ff92SAndroid Build Coastguard Worker              tensorDataType == DataType::Signed32 ||
1069*89c4ff92SAndroid Build Coastguard Worker              tensorDataType == DataType::Signed64));
1070*89c4ff92SAndroid Build Coastguard Worker }
1071*89c4ff92SAndroid Build Coastguard Worker 
RegisterProducerOfTensor(size_t subgraphIndex,size_t tensorIndex,armnn::IOutputSlot * slot)1072*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::RegisterProducerOfTensor(size_t subgraphIndex,
1073*89c4ff92SAndroid Build Coastguard Worker                                                 size_t tensorIndex,
1074*89c4ff92SAndroid Build Coastguard Worker                                                 armnn::IOutputSlot* slot)
1075*89c4ff92SAndroid Build Coastguard Worker {
1076*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR(m_Model, subgraphIndex, tensorIndex);
1077*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(m_SubgraphConnections.size() > subgraphIndex);
1078*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(m_SubgraphConnections[subgraphIndex].size() > tensorIndex);
1079*89c4ff92SAndroid Build Coastguard Worker 
1080*89c4ff92SAndroid Build Coastguard Worker     TensorSlots & tensorSlots = m_SubgraphConnections[subgraphIndex][tensorIndex];
1081*89c4ff92SAndroid Build Coastguard Worker 
1082*89c4ff92SAndroid Build Coastguard Worker     if (slot->GetOwningIConnectableLayer().GetType() != LayerType::Constant)
1083*89c4ff92SAndroid Build Coastguard Worker     {
1084*89c4ff92SAndroid Build Coastguard Worker 
1085*89c4ff92SAndroid Build Coastguard Worker         // assuming there is only one producer for that tensor
1086*89c4ff92SAndroid Build Coastguard Worker         if (tensorSlots.outputSlot != nullptr)
1087*89c4ff92SAndroid Build Coastguard Worker         {
1088*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(fmt::format("Another layer has already registered itself as the producer of "
1089*89c4ff92SAndroid Build Coastguard Worker                                              "subgraph:{} tensor:{} {}",
1090*89c4ff92SAndroid Build Coastguard Worker                                              subgraphIndex,
1091*89c4ff92SAndroid Build Coastguard Worker                                              tensorIndex,
1092*89c4ff92SAndroid Build Coastguard Worker                                              CHECK_LOCATION().AsString()));
1093*89c4ff92SAndroid Build Coastguard Worker         }
1094*89c4ff92SAndroid Build Coastguard Worker     }
1095*89c4ff92SAndroid Build Coastguard Worker     tensorSlots.outputSlot = slot;
1096*89c4ff92SAndroid Build Coastguard Worker }
1097*89c4ff92SAndroid Build Coastguard Worker 
RegisterConsumerOfTensor(size_t subgraphIndex,size_t tensorIndex,armnn::IInputSlot * slot)1098*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::RegisterConsumerOfTensor(size_t subgraphIndex,
1099*89c4ff92SAndroid Build Coastguard Worker                                                 size_t tensorIndex,
1100*89c4ff92SAndroid Build Coastguard Worker                                                 armnn::IInputSlot* slot)
1101*89c4ff92SAndroid Build Coastguard Worker {
1102*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR(m_Model, subgraphIndex, tensorIndex);
1103*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(m_SubgraphConnections.size() > subgraphIndex);
1104*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(m_SubgraphConnections[subgraphIndex].size() > tensorIndex);
1105*89c4ff92SAndroid Build Coastguard Worker 
1106*89c4ff92SAndroid Build Coastguard Worker     TensorSlots& tensorSlots = m_SubgraphConnections[subgraphIndex][tensorIndex];
1107*89c4ff92SAndroid Build Coastguard Worker     tensorSlots.inputSlots.push_back(slot);
1108*89c4ff92SAndroid Build Coastguard Worker }
1109*89c4ff92SAndroid Build Coastguard Worker 
ParseCustomOperator(size_t subgraphIndex,size_t operatorIndex)1110*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseCustomOperator(size_t subgraphIndex, size_t operatorIndex)
1111*89c4ff92SAndroid Build Coastguard Worker {
1112*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1113*89c4ff92SAndroid Build Coastguard Worker 
1114*89c4ff92SAndroid Build Coastguard Worker     // NOTE: By default we presume the custom operator is not supported
1115*89c4ff92SAndroid Build Coastguard Worker     auto customParserFunction = &TfLiteParserImpl::ParseUnsupportedOperator;
1116*89c4ff92SAndroid Build Coastguard Worker 
1117*89c4ff92SAndroid Build Coastguard Worker     // Identify custom code defined for custom operator
1118*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1119*89c4ff92SAndroid Build Coastguard Worker     const auto& customCode  = m_Model->operator_codes[operatorPtr->opcode_index]->custom_code;
1120*89c4ff92SAndroid Build Coastguard Worker 
1121*89c4ff92SAndroid Build Coastguard Worker     // Find parser function that corresponds to custom code (if any)
1122*89c4ff92SAndroid Build Coastguard Worker     auto iterator = m_CustomParserFunctions.find(customCode);
1123*89c4ff92SAndroid Build Coastguard Worker     if (iterator != m_CustomParserFunctions.end())
1124*89c4ff92SAndroid Build Coastguard Worker     {
1125*89c4ff92SAndroid Build Coastguard Worker         customParserFunction = iterator->second;
1126*89c4ff92SAndroid Build Coastguard Worker     }
1127*89c4ff92SAndroid Build Coastguard Worker 
1128*89c4ff92SAndroid Build Coastguard Worker     // Run parser function
1129*89c4ff92SAndroid Build Coastguard Worker     (this->*customParserFunction)(subgraphIndex, operatorIndex);
1130*89c4ff92SAndroid Build Coastguard Worker }
1131*89c4ff92SAndroid Build Coastguard Worker 
ParseUnsupportedOperator(size_t subgraphIndex,size_t operatorIndex)1132*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseUnsupportedOperator(size_t subgraphIndex, size_t operatorIndex)
1133*89c4ff92SAndroid Build Coastguard Worker {
1134*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1135*89c4ff92SAndroid Build Coastguard Worker 
1136*89c4ff92SAndroid Build Coastguard Worker     const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1137*89c4ff92SAndroid Build Coastguard Worker 
1138*89c4ff92SAndroid Build Coastguard Worker     auto opcodeIndex = operatorPtr->opcode_index;
1139*89c4ff92SAndroid Build Coastguard Worker 
1140*89c4ff92SAndroid Build Coastguard Worker // work around the introduction of the deprecated_builtin_code introduced in 2.4 in a backwards compatible manner
1141*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_3)
1142*89c4ff92SAndroid Build Coastguard Worker     auto opcode      = std::max(m_Model->operator_codes[opcodeIndex]->builtin_code,
1143*89c4ff92SAndroid Build Coastguard Worker             static_cast<tflite::BuiltinOperator>(m_Model->operator_codes[opcodeIndex]->deprecated_builtin_code));
1144*89c4ff92SAndroid Build Coastguard Worker #else
1145*89c4ff92SAndroid Build Coastguard Worker     auto opcode      = m_Model->operator_codes[opcodeIndex]->builtin_code;
1146*89c4ff92SAndroid Build Coastguard Worker #endif
1147*89c4ff92SAndroid Build Coastguard Worker 
1148*89c4ff92SAndroid Build Coastguard Worker     if (!m_Options || !m_Options.value().m_StandInLayerForUnsupported)
1149*89c4ff92SAndroid Build Coastguard Worker     {
1150*89c4ff92SAndroid Build Coastguard Worker         // Do not add StandInLayer, throw ParseException instead
1151*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
1152*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Operator not supported. "
1153*89c4ff92SAndroid Build Coastguard Worker                         "subgraph:{} operator:{} "
1154*89c4ff92SAndroid Build Coastguard Worker                         "opcode_index:{} opcode:{} / {} {}",
1155*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
1156*89c4ff92SAndroid Build Coastguard Worker                         operatorIndex,
1157*89c4ff92SAndroid Build Coastguard Worker                         opcodeIndex,
1158*89c4ff92SAndroid Build Coastguard Worker                         opcode,
1159*89c4ff92SAndroid Build Coastguard Worker                         tflite::EnumNameBuiltinOperator(opcode),
1160*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
1161*89c4ff92SAndroid Build Coastguard Worker     }
1162*89c4ff92SAndroid Build Coastguard Worker 
1163*89c4ff92SAndroid Build Coastguard Worker     auto inputs  = GetInputs(m_Model, subgraphIndex, operatorIndex);
1164*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1165*89c4ff92SAndroid Build Coastguard Worker 
1166*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numInputs  = armnn::numeric_cast<unsigned int>(inputs.size());
1167*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numOutputs = armnn::numeric_cast<unsigned int>(outputs.size());
1168*89c4ff92SAndroid Build Coastguard Worker 
1169*89c4ff92SAndroid Build Coastguard Worker     StandInDescriptor descriptor(numInputs, numOutputs);
1170*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("StandIn:{}:{}:{}", subgraphIndex, operatorIndex, opcode);
1171*89c4ff92SAndroid Build Coastguard Worker 
1172*89c4ff92SAndroid Build Coastguard Worker     // Add a non-executable StandInLayer as a placeholder for any unsupported operator
1173*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddStandInLayer(descriptor, layerName.c_str());
1174*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1175*89c4ff92SAndroid Build Coastguard Worker 
1176*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0u; i < numOutputs; ++i)
1177*89c4ff92SAndroid Build Coastguard Worker     {
1178*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(i).SetTensorInfo(ToTensorInfo(outputs[0], true));
1179*89c4ff92SAndroid Build Coastguard Worker     }
1180*89c4ff92SAndroid Build Coastguard Worker 
1181*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIds  = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1182*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIds = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1183*89c4ff92SAndroid Build Coastguard Worker 
1184*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, inputTensorIds);
1185*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIds);
1186*89c4ff92SAndroid Build Coastguard Worker }
1187*89c4ff92SAndroid Build Coastguard Worker 
ParseCast(size_t subgraphIndex,size_t operatorIndex)1188*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseCast(size_t subgraphIndex, size_t operatorIndex)
1189*89c4ff92SAndroid Build Coastguard Worker {
1190*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1191*89c4ff92SAndroid Build Coastguard Worker 
1192*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1193*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
1194*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1195*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1196*89c4ff92SAndroid Build Coastguard Worker 
1197*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Cast:{}:{}", subgraphIndex, operatorIndex);
1198*89c4ff92SAndroid Build Coastguard Worker 
1199*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddCastLayer(layerName.c_str());
1200*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1201*89c4ff92SAndroid Build Coastguard Worker 
1202*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
1203*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1204*89c4ff92SAndroid Build Coastguard Worker 
1205*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1206*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1207*89c4ff92SAndroid Build Coastguard Worker 
1208*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1209*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
1210*89c4ff92SAndroid Build Coastguard Worker }
1211*89c4ff92SAndroid Build Coastguard Worker 
ParseConv2D(size_t subgraphIndex,size_t operatorIndex)1212*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseConv2D(size_t subgraphIndex, size_t operatorIndex)
1213*89c4ff92SAndroid Build Coastguard Worker {
1214*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1215*89c4ff92SAndroid Build Coastguard Worker 
1216*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1217*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsConv2DOptions();
1218*89c4ff92SAndroid Build Coastguard Worker 
1219*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
1220*89c4ff92SAndroid Build Coastguard Worker 
1221*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1222*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1223*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1224*89c4ff92SAndroid Build Coastguard Worker 
1225*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor desc;
1226*89c4ff92SAndroid Build Coastguard Worker     inputs.size() == 3 ?
1227*89c4ff92SAndroid Build Coastguard Worker         desc.m_BiasEnabled = true : desc.m_BiasEnabled = false;
1228*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = CHECKED_NON_NEGATIVE(options->stride_w);
1229*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h);
1230*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
1231*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
1232*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
1233*89c4ff92SAndroid Build Coastguard Worker 
1234*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1235*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo filterTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1236*89c4ff92SAndroid Build Coastguard Worker 
1237*89c4ff92SAndroid Build Coastguard Worker     // assuming input is NHWC
1238*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = inputTensorInfo.GetShape()[1];
1239*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth = inputTensorInfo.GetShape()[2];
1240*89c4ff92SAndroid Build Coastguard Worker 
1241*89c4ff92SAndroid Build Coastguard Worker     // assuming the filter is OHWI : Output, H, W, Input
1242*89c4ff92SAndroid Build Coastguard Worker     // which is essentially the same as NHWC
1243*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterHeight = filterTensorInfo.GetShape()[1];
1244*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterWidth = filterTensorInfo.GetShape()[2];
1245*89c4ff92SAndroid Build Coastguard Worker 
1246*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
1247*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
1248*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
1249*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
1250*89c4ff92SAndroid Build Coastguard Worker 
1251*89c4ff92SAndroid Build Coastguard Worker     // Add the first input and weights tensor to the registration list.
1252*89c4ff92SAndroid Build Coastguard Worker     // The constant weights will be added by SetupConstantLayers.
1253*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1254*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> tensorIndexesToRegister = { inputTensorIndexes[0], inputTensorIndexes[1] };
1255*89c4ff92SAndroid Build Coastguard Worker 
1256*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Conv2D:{}:{}", subgraphIndex, operatorIndex);
1257*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = m_Network->AddConvolution2dLayer(desc, layerName.c_str());
1258*89c4ff92SAndroid Build Coastguard Worker 
1259*89c4ff92SAndroid Build Coastguard Worker     if (ShouldConstantTensorBeConverted(inputs[1], inputTensorInfo.GetDataType(), filterTensorInfo.GetDataType()))
1260*89c4ff92SAndroid Build Coastguard Worker     {
1261*89c4ff92SAndroid Build Coastguard Worker         m_ConstantsToDequantize.emplace_back(inputs[1]->buffer);
1262*89c4ff92SAndroid Build Coastguard Worker     }
1263*89c4ff92SAndroid Build Coastguard Worker 
1264*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_BiasEnabled)
1265*89c4ff92SAndroid Build Coastguard Worker     {
1266*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo biasTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
1267*89c4ff92SAndroid Build Coastguard Worker 
1268*89c4ff92SAndroid Build Coastguard Worker         // Add the biases input to the registration list, a constant layer will be added by SetupConstantLayers.
1269*89c4ff92SAndroid Build Coastguard Worker         tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
1270*89c4ff92SAndroid Build Coastguard Worker 
1271*89c4ff92SAndroid Build Coastguard Worker         if (ShouldConstantTensorBeConverted(inputs[2], inputTensorInfo.GetDataType(), biasTensorInfo.GetDataType()))
1272*89c4ff92SAndroid Build Coastguard Worker         {
1273*89c4ff92SAndroid Build Coastguard Worker             m_ConstantsToDequantize.emplace_back(inputs[2]->buffer);
1274*89c4ff92SAndroid Build Coastguard Worker         }
1275*89c4ff92SAndroid Build Coastguard Worker     }
1276*89c4ff92SAndroid Build Coastguard Worker 
1277*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1278*89c4ff92SAndroid Build Coastguard Worker 
1279*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
1280*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1281*89c4ff92SAndroid Build Coastguard Worker 
1282*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
1283*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
1284*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, tensorIndexesToRegister);
1285*89c4ff92SAndroid Build Coastguard Worker 
1286*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
1287*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
1288*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1289*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, { outputTensorIndexes[0] });
1290*89c4ff92SAndroid Build Coastguard Worker }
1291*89c4ff92SAndroid Build Coastguard Worker 
1292*89c4ff92SAndroid Build Coastguard Worker // Conv3D support was added in TF 2.5, so for backwards compatibility a hash define is needed.
1293*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_4)
ParseConv3D(size_t subgraphIndex,size_t operatorIndex)1294*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseConv3D(size_t subgraphIndex, size_t operatorIndex)
1295*89c4ff92SAndroid Build Coastguard Worker {
1296*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1297*89c4ff92SAndroid Build Coastguard Worker 
1298*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1299*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsConv3DOptions();
1300*89c4ff92SAndroid Build Coastguard Worker 
1301*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
1302*89c4ff92SAndroid Build Coastguard Worker 
1303*89c4ff92SAndroid Build Coastguard Worker     Convolution3dDescriptor desc;
1304*89c4ff92SAndroid Build Coastguard Worker     desc.m_BiasEnabled = false;
1305*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NDHWC;
1306*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = CHECKED_NON_NEGATIVE(options->stride_w);
1307*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h);
1308*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideZ = CHECKED_NON_NEGATIVE(options->stride_d);
1309*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
1310*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
1311*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationZ = CHECKED_NON_NEGATIVE(options->dilation_d_factor);
1312*89c4ff92SAndroid Build Coastguard Worker 
1313*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1314*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2, 3);
1315*89c4ff92SAndroid Build Coastguard Worker 
1316*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1317*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1318*89c4ff92SAndroid Build Coastguard Worker 
1319*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1320*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo filterTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1321*89c4ff92SAndroid Build Coastguard Worker 
1322*89c4ff92SAndroid Build Coastguard Worker     // Assuming input is NDHWC
1323*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputDepth  = inputTensorInfo.GetShape()[1];
1324*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = inputTensorInfo.GetShape()[2];
1325*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth  = inputTensorInfo.GetShape()[3];
1326*89c4ff92SAndroid Build Coastguard Worker 
1327*89c4ff92SAndroid Build Coastguard Worker     // Assuming the filter is DHWIO : Depth, Height, Width, OutputChannels, InputChannels
1328*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterDepth  = filterTensorInfo.GetShape()[0];
1329*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterHeight = filterTensorInfo.GetShape()[1];
1330*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterWidth  = filterTensorInfo.GetShape()[2];
1331*89c4ff92SAndroid Build Coastguard Worker 
1332*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputDepth, filterDepth, desc.m_StrideZ,
1333*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationZ, desc.m_PadFront, desc.m_PadBack, options->padding);
1334*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
1335*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
1336*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
1337*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
1338*89c4ff92SAndroid Build Coastguard Worker 
1339*89c4ff92SAndroid Build Coastguard Worker     auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo, inputTensorInfo.GetDataType());
1340*89c4ff92SAndroid Build Coastguard Worker 
1341*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Conv3D:{}:{}", subgraphIndex, operatorIndex);
1342*89c4ff92SAndroid Build Coastguard Worker 
1343*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1344*89c4ff92SAndroid Build Coastguard Worker     // Add the first input and weights tensor to the registration list.
1345*89c4ff92SAndroid Build Coastguard Worker     // The constant weights will be added by SetupConstantLayers.
1346*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0], inputTensorIndexes[1]};
1347*89c4ff92SAndroid Build Coastguard Worker 
1348*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() == 3)
1349*89c4ff92SAndroid Build Coastguard Worker     {
1350*89c4ff92SAndroid Build Coastguard Worker         desc.m_BiasEnabled = true;
1351*89c4ff92SAndroid Build Coastguard Worker 
1352*89c4ff92SAndroid Build Coastguard Worker         // Add the biases input to the registration list, a constant layer will be added by SetupConstantLayers.
1353*89c4ff92SAndroid Build Coastguard Worker         tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
1354*89c4ff92SAndroid Build Coastguard Worker     }
1355*89c4ff92SAndroid Build Coastguard Worker 
1356*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = m_Network->AddConvolution3dLayer(desc, layerName.c_str());
1357*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1358*89c4ff92SAndroid Build Coastguard Worker 
1359*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
1360*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1361*89c4ff92SAndroid Build Coastguard Worker 
1362*89c4ff92SAndroid Build Coastguard Worker     // Register the input connection slots for the layer, connections are made after all layers have been created
1363*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, tensorIndexesToRegister);
1364*89c4ff92SAndroid Build Coastguard Worker 
1365*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
1366*89c4ff92SAndroid Build Coastguard Worker     // Register the output connection slots for the layer, connections are made after all layers have been created
1367*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1368*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1369*89c4ff92SAndroid Build Coastguard Worker }
1370*89c4ff92SAndroid Build Coastguard Worker #endif
1371*89c4ff92SAndroid Build Coastguard Worker 
ParseDepthwiseConv2D(size_t subgraphIndex,size_t operatorIndex)1372*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseDepthwiseConv2D(size_t subgraphIndex, size_t operatorIndex)
1373*89c4ff92SAndroid Build Coastguard Worker {
1374*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1375*89c4ff92SAndroid Build Coastguard Worker 
1376*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1377*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsDepthwiseConv2DOptions();
1378*89c4ff92SAndroid Build Coastguard Worker 
1379*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
1380*89c4ff92SAndroid Build Coastguard Worker 
1381*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dDescriptor desc;
1382*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = CHECKED_NON_NEGATIVE(options->stride_w);
1383*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h);
1384*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
1385*89c4ff92SAndroid Build Coastguard Worker     CHECKED_NON_NEGATIVE(options->depth_multiplier);
1386*89c4ff92SAndroid Build Coastguard Worker 
1387*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1388*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2, 3);
1389*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() == 3)
1390*89c4ff92SAndroid Build Coastguard Worker     {
1391*89c4ff92SAndroid Build Coastguard Worker         desc.m_BiasEnabled = true;
1392*89c4ff92SAndroid Build Coastguard Worker     }
1393*89c4ff92SAndroid Build Coastguard Worker 
1394*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1395*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1396*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationX = CHECKED_NON_NEGATIVE(options->dilation_w_factor);
1397*89c4ff92SAndroid Build Coastguard Worker     desc.m_DilationY = CHECKED_NON_NEGATIVE(options->dilation_h_factor);
1398*89c4ff92SAndroid Build Coastguard Worker 
1399*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1400*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo filterTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1401*89c4ff92SAndroid Build Coastguard Worker 
1402*89c4ff92SAndroid Build Coastguard Worker     // Assuming input is NHWC
1403*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = inputTensorInfo.GetShape()[1];
1404*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth  = inputTensorInfo.GetShape()[2];
1405*89c4ff92SAndroid Build Coastguard Worker 
1406*89c4ff92SAndroid Build Coastguard Worker     // TensorflowLite weights come in the format [1, H, W, I * M]
1407*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterHeight = filterTensorInfo.GetShape()[1];
1408*89c4ff92SAndroid Build Coastguard Worker     unsigned int filterWidth  = filterTensorInfo.GetShape()[2];
1409*89c4ff92SAndroid Build Coastguard Worker 
1410*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight, filterHeight, desc.m_StrideY,
1411*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationY, desc.m_PadTop, desc.m_PadBottom, options->padding);
1412*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth, filterWidth, desc.m_StrideX,
1413*89c4ff92SAndroid Build Coastguard Worker                 desc.m_DilationX, desc.m_PadLeft, desc.m_PadRight, options->padding);
1414*89c4ff92SAndroid Build Coastguard Worker 
1415*89c4ff92SAndroid Build Coastguard Worker     // ArmNN uses the same filter tensor layout at TfLite [1, H, W, O] no need for any permutation
1416*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("DepthwiseConv2D:{}:{}", subgraphIndex, operatorIndex);
1417*89c4ff92SAndroid Build Coastguard Worker 
1418*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1419*89c4ff92SAndroid Build Coastguard Worker     // Add the first input and weights tensor to the registration list.
1420*89c4ff92SAndroid Build Coastguard Worker     // The constant weights will be added by SetupConstantLayers.
1421*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0], inputTensorIndexes[1]};
1422*89c4ff92SAndroid Build Coastguard Worker 
1423*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = m_Network->AddDepthwiseConvolution2dLayer(desc, layerName.c_str());
1424*89c4ff92SAndroid Build Coastguard Worker 
1425*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_BiasEnabled)
1426*89c4ff92SAndroid Build Coastguard Worker     {
1427*89c4ff92SAndroid Build Coastguard Worker         desc.m_BiasEnabled = true;
1428*89c4ff92SAndroid Build Coastguard Worker         TensorInfo biasTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
1429*89c4ff92SAndroid Build Coastguard Worker 
1430*89c4ff92SAndroid Build Coastguard Worker         // Add the biases input to the registration list, a constant layer will be added by SetupConstantLayers.
1431*89c4ff92SAndroid Build Coastguard Worker         tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
1432*89c4ff92SAndroid Build Coastguard Worker     }
1433*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1434*89c4ff92SAndroid Build Coastguard Worker 
1435*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
1436*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1437*89c4ff92SAndroid Build Coastguard Worker 
1438*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
1439*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
1440*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, tensorIndexesToRegister);
1441*89c4ff92SAndroid Build Coastguard Worker 
1442*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
1443*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
1444*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1445*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1446*89c4ff92SAndroid Build Coastguard Worker }
1447*89c4ff92SAndroid Build Coastguard Worker 
ParseDequantize(size_t subgraphIndex,size_t operatorIndex)1448*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseDequantize(size_t subgraphIndex, size_t operatorIndex)
1449*89c4ff92SAndroid Build Coastguard Worker {
1450*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1451*89c4ff92SAndroid Build Coastguard Worker 
1452*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1453*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
1454*89c4ff92SAndroid Build Coastguard Worker 
1455*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1456*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1457*89c4ff92SAndroid Build Coastguard Worker 
1458*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Dequantize:{}:{}", subgraphIndex, operatorIndex);
1459*89c4ff92SAndroid Build Coastguard Worker 
1460*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddDequantizeLayer(layerName.c_str());
1461*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1462*89c4ff92SAndroid Build Coastguard Worker 
1463*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
1464*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1465*89c4ff92SAndroid Build Coastguard Worker 
1466*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1467*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1468*89c4ff92SAndroid Build Coastguard Worker 
1469*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1470*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
1471*89c4ff92SAndroid Build Coastguard Worker }
1472*89c4ff92SAndroid Build Coastguard Worker 
ParseExpandDims(size_t subgraphIndex,size_t operatorIndex)1473*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseExpandDims(size_t subgraphIndex, size_t operatorIndex)
1474*89c4ff92SAndroid Build Coastguard Worker {
1475*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1476*89c4ff92SAndroid Build Coastguard Worker 
1477*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1478*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
1479*89c4ff92SAndroid Build Coastguard Worker 
1480*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1481*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1482*89c4ff92SAndroid Build Coastguard Worker 
1483*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("ExpandDims:{}:{}", subgraphIndex, operatorIndex);
1484*89c4ff92SAndroid Build Coastguard Worker 
1485*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1486*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
1487*89c4ff92SAndroid Build Coastguard Worker 
1488*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
1489*89c4ff92SAndroid Build Coastguard Worker 
1490*89c4ff92SAndroid Build Coastguard Worker     ReshapeDescriptor reshapeDesc;
1491*89c4ff92SAndroid Build Coastguard Worker 
1492*89c4ff92SAndroid Build Coastguard Worker     if (outputTensorInfo.GetShape().AreAllDimensionsSpecified())
1493*89c4ff92SAndroid Build Coastguard Worker     {
1494*89c4ff92SAndroid Build Coastguard Worker         reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
1495*89c4ff92SAndroid Build Coastguard Worker     }
1496*89c4ff92SAndroid Build Coastguard Worker     else
1497*89c4ff92SAndroid Build Coastguard Worker     {
1498*89c4ff92SAndroid Build Coastguard Worker         int32_t axis = inputs[1]->shape[0];
1499*89c4ff92SAndroid Build Coastguard Worker 
1500*89c4ff92SAndroid Build Coastguard Worker         int32_t inputDimSize = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions());
1501*89c4ff92SAndroid Build Coastguard Worker 
1502*89c4ff92SAndroid Build Coastguard Worker         if (axis > inputDimSize || axis < 0 - (inputDimSize + 1))
1503*89c4ff92SAndroid Build Coastguard Worker         {
1504*89c4ff92SAndroid Build Coastguard Worker             throw ParseException("axis must be in range [0 - (inputDimSize + 1), inputDimSize] inclusive");
1505*89c4ff92SAndroid Build Coastguard Worker         }
1506*89c4ff92SAndroid Build Coastguard Worker 
1507*89c4ff92SAndroid Build Coastguard Worker         if(axis < 0)
1508*89c4ff92SAndroid Build Coastguard Worker         {
1509*89c4ff92SAndroid Build Coastguard Worker             axis = inputDimSize + axis + 1;
1510*89c4ff92SAndroid Build Coastguard Worker         }
1511*89c4ff92SAndroid Build Coastguard Worker 
1512*89c4ff92SAndroid Build Coastguard Worker         std::vector<unsigned int> shape(static_cast<unsigned int>(inputDimSize) + 1);
1513*89c4ff92SAndroid Build Coastguard Worker         unsigned int inputShapeIndex = 0;
1514*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < static_cast<unsigned int>(inputDimSize + 1); ++i)
1515*89c4ff92SAndroid Build Coastguard Worker         {
1516*89c4ff92SAndroid Build Coastguard Worker             if (i == static_cast<unsigned int>(axis))
1517*89c4ff92SAndroid Build Coastguard Worker             {
1518*89c4ff92SAndroid Build Coastguard Worker                 shape[i] = 1;
1519*89c4ff92SAndroid Build Coastguard Worker             }
1520*89c4ff92SAndroid Build Coastguard Worker             else
1521*89c4ff92SAndroid Build Coastguard Worker             {
1522*89c4ff92SAndroid Build Coastguard Worker                 shape[i] = inputTensorInfo.GetShape()[inputShapeIndex];
1523*89c4ff92SAndroid Build Coastguard Worker                 ++inputShapeIndex;
1524*89c4ff92SAndroid Build Coastguard Worker             }
1525*89c4ff92SAndroid Build Coastguard Worker         }
1526*89c4ff92SAndroid Build Coastguard Worker 
1527*89c4ff92SAndroid Build Coastguard Worker         reshapeDesc.m_TargetShape = TensorShape(static_cast<unsigned int>(inputDimSize + 1), shape.data());
1528*89c4ff92SAndroid Build Coastguard Worker     }
1529*89c4ff92SAndroid Build Coastguard Worker 
1530*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
1531*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1532*89c4ff92SAndroid Build Coastguard Worker 
1533*89c4ff92SAndroid Build Coastguard Worker     reshapeDesc.m_TargetShape = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0}).GetShape();
1534*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo.SetShape(reshapeDesc.m_TargetShape);
1535*89c4ff92SAndroid Build Coastguard Worker 
1536*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1537*89c4ff92SAndroid Build Coastguard Worker 
1538*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1539*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1540*89c4ff92SAndroid Build Coastguard Worker 
1541*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1542*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1543*89c4ff92SAndroid Build Coastguard Worker }
1544*89c4ff92SAndroid Build Coastguard Worker 
ParseTranspose(size_t subgraphIndex,size_t operatorIndex)1545*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseTranspose(size_t subgraphIndex, size_t operatorIndex)
1546*89c4ff92SAndroid Build Coastguard Worker {
1547*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1548*89c4ff92SAndroid Build Coastguard Worker 
1549*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1550*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1, 2);
1551*89c4ff92SAndroid Build Coastguard Worker 
1552*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1553*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1554*89c4ff92SAndroid Build Coastguard Worker 
1555*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Transpose:{}:{}", subgraphIndex, operatorIndex);
1556*89c4ff92SAndroid Build Coastguard Worker     TransposeDescriptor desc;
1557*89c4ff92SAndroid Build Coastguard Worker 
1558*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() == 2)
1559*89c4ff92SAndroid Build Coastguard Worker     {
1560*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo permuteTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1561*89c4ff92SAndroid Build Coastguard Worker         BufferRawPtr permuteBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
1562*89c4ff92SAndroid Build Coastguard Worker         auto numPermVecElements = permuteTensorInfo.GetNumElements();
1563*89c4ff92SAndroid Build Coastguard Worker         std::vector<unsigned int> permuteShape(numPermVecElements);
1564*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(permuteShape.data(), permuteBufferPtr->data.data(), permuteTensorInfo.GetNumBytes());
1565*89c4ff92SAndroid Build Coastguard Worker         PermutationVector permutationVector(permuteShape.data(), permuteTensorInfo.GetNumElements());
1566*89c4ff92SAndroid Build Coastguard Worker 
1567*89c4ff92SAndroid Build Coastguard Worker         desc = TransposeDescriptor(permutationVector);
1568*89c4ff92SAndroid Build Coastguard Worker     }
1569*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1570*89c4ff92SAndroid Build Coastguard Worker 
1571*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddTransposeLayer(desc, layerName.c_str());
1572*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1573*89c4ff92SAndroid Build Coastguard Worker 
1574*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
1575*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
1576*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1577*89c4ff92SAndroid Build Coastguard Worker 
1578*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1579*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1580*89c4ff92SAndroid Build Coastguard Worker 
1581*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1582*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1583*89c4ff92SAndroid Build Coastguard Worker }
1584*89c4ff92SAndroid Build Coastguard Worker 
ParseTransposeConv(size_t subgraphIndex,size_t operatorIndex)1585*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseTransposeConv(size_t subgraphIndex, size_t operatorIndex)
1586*89c4ff92SAndroid Build Coastguard Worker {
1587*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1588*89c4ff92SAndroid Build Coastguard Worker 
1589*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1590*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsTransposeConvOptions();
1591*89c4ff92SAndroid Build Coastguard Worker 
1592*89c4ff92SAndroid Build Coastguard Worker     TransposeConvolution2dDescriptor desc;
1593*89c4ff92SAndroid Build Coastguard Worker     desc.m_BiasEnabled = false;
1594*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = CHECKED_NON_NEGATIVE(options->stride_w);
1595*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h);
1596*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
1597*89c4ff92SAndroid Build Coastguard Worker 
1598*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1599*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() == 4)
1600*89c4ff92SAndroid Build Coastguard Worker     {
1601*89c4ff92SAndroid Build Coastguard Worker         desc.m_BiasEnabled = true;
1602*89c4ff92SAndroid Build Coastguard Worker     }
1603*89c4ff92SAndroid Build Coastguard Worker     else
1604*89c4ff92SAndroid Build Coastguard Worker     {
1605*89c4ff92SAndroid Build Coastguard Worker         CHECK_VALID_SIZE(inputs.size(), 3);
1606*89c4ff92SAndroid Build Coastguard Worker     }
1607*89c4ff92SAndroid Build Coastguard Worker 
1608*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1609*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1610*89c4ff92SAndroid Build Coastguard Worker 
1611*89c4ff92SAndroid Build Coastguard Worker     // This block determines the output shape of the transpose convolution. If the output shape tensor ptr is not null
1612*89c4ff92SAndroid Build Coastguard Worker     // And the tensor is a constant, we can access the data at load time and set the output shape of the
1613*89c4ff92SAndroid Build Coastguard Worker     // layer. If this is not constant, We do not have access to the shape data, so we have to use
1614*89c4ff92SAndroid Build Coastguard Worker     // infer output shape and skip this code block.
1615*89c4ff92SAndroid Build Coastguard Worker     if (inputs[0] && IsConstTensor(inputs[0]))
1616*89c4ff92SAndroid Build Coastguard Worker     {
1617*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo tensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1618*89c4ff92SAndroid Build Coastguard Worker         std::vector<int> output_shape(tensorInfo.GetNumElements());
1619*89c4ff92SAndroid Build Coastguard Worker 
1620*89c4ff92SAndroid Build Coastguard Worker         if (tensorInfo.GetDataType() == DataType::Signed32)
1621*89c4ff92SAndroid Build Coastguard Worker         {
1622*89c4ff92SAndroid Build Coastguard Worker             ::memcpy(output_shape.data(), GetBuffer(m_Model, inputs[0]->buffer)->data.data(), tensorInfo.GetNumBytes());
1623*89c4ff92SAndroid Build Coastguard Worker         }
1624*89c4ff92SAndroid Build Coastguard Worker         if (tensorInfo.GetDataType() == DataType::QAsymmU8)
1625*89c4ff92SAndroid Build Coastguard Worker         {
1626*89c4ff92SAndroid Build Coastguard Worker             for(unsigned int i=0; i < tensorInfo.GetNumElements(); i++)
1627*89c4ff92SAndroid Build Coastguard Worker             {
1628*89c4ff92SAndroid Build Coastguard Worker                 output_shape[i] = GetBuffer(m_Model, inputs[0]->buffer)->data.data()[i];
1629*89c4ff92SAndroid Build Coastguard Worker             }
1630*89c4ff92SAndroid Build Coastguard Worker         }
1631*89c4ff92SAndroid Build Coastguard Worker         // Change from signed to unsigned int to store in TransposeConvolution2dDescriptor.
1632*89c4ff92SAndroid Build Coastguard Worker         for (int dimension : output_shape)
1633*89c4ff92SAndroid Build Coastguard Worker         {
1634*89c4ff92SAndroid Build Coastguard Worker             desc.m_OutputShape.push_back(static_cast<unsigned int>(dimension));
1635*89c4ff92SAndroid Build Coastguard Worker         }
1636*89c4ff92SAndroid Build Coastguard Worker         desc.m_OutputShapeEnabled = true;
1637*89c4ff92SAndroid Build Coastguard Worker     }
1638*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 2);
1639*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo filterTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1640*89c4ff92SAndroid Build Coastguard Worker 
1641*89c4ff92SAndroid Build Coastguard Worker     // TfLite uses NHWC tensors
1642*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputHeight = inputTensorInfo.GetShape()[1];
1643*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputWidth  = inputTensorInfo.GetShape()[2];
1644*89c4ff92SAndroid Build Coastguard Worker 
1645*89c4ff92SAndroid Build Coastguard Worker     const unsigned int filterHeight = filterTensorInfo.GetShape()[1];
1646*89c4ff92SAndroid Build Coastguard Worker     const unsigned int filterWidth  = filterTensorInfo.GetShape()[2];
1647*89c4ff92SAndroid Build Coastguard Worker 
1648*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight,
1649*89c4ff92SAndroid Build Coastguard Worker                 filterHeight,
1650*89c4ff92SAndroid Build Coastguard Worker                 desc.m_StrideY,
1651*89c4ff92SAndroid Build Coastguard Worker                 1, // DilationY
1652*89c4ff92SAndroid Build Coastguard Worker                 desc.m_PadTop,
1653*89c4ff92SAndroid Build Coastguard Worker                 desc.m_PadBottom,
1654*89c4ff92SAndroid Build Coastguard Worker                 options->padding);
1655*89c4ff92SAndroid Build Coastguard Worker 
1656*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth,
1657*89c4ff92SAndroid Build Coastguard Worker                 filterWidth,
1658*89c4ff92SAndroid Build Coastguard Worker                 desc.m_StrideX,
1659*89c4ff92SAndroid Build Coastguard Worker                 1, // DilationX
1660*89c4ff92SAndroid Build Coastguard Worker                 desc.m_PadLeft,
1661*89c4ff92SAndroid Build Coastguard Worker                 desc.m_PadRight,
1662*89c4ff92SAndroid Build Coastguard Worker                 options->padding);
1663*89c4ff92SAndroid Build Coastguard Worker 
1664*89c4ff92SAndroid Build Coastguard Worker     auto filterTensorAndData = CreateConstTensorNonPermuted(inputs[1], filterTensorInfo, inputTensorInfo.GetDataType());
1665*89c4ff92SAndroid Build Coastguard Worker 
1666*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = nullptr;
1667*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("TransposeConv:{}:{}", subgraphIndex, operatorIndex);
1668*89c4ff92SAndroid Build Coastguard Worker 
1669*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_BiasEnabled)
1670*89c4ff92SAndroid Build Coastguard Worker     {
1671*89c4ff92SAndroid Build Coastguard Worker         auto biasTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 3);
1672*89c4ff92SAndroid Build Coastguard Worker         auto biasConstTensor = CreateConstTensorNonPermuted(inputs[3], biasTensorInfo, inputTensorInfo.GetDataType());
1673*89c4ff92SAndroid Build Coastguard Worker         layer = m_Network->AddTransposeConvolution2dLayer(desc,
1674*89c4ff92SAndroid Build Coastguard Worker                                                           filterTensorAndData.first,
1675*89c4ff92SAndroid Build Coastguard Worker                                                           biasConstTensor.first,
1676*89c4ff92SAndroid Build Coastguard Worker                                                           layerName.c_str());
1677*89c4ff92SAndroid Build Coastguard Worker     }
1678*89c4ff92SAndroid Build Coastguard Worker     else
1679*89c4ff92SAndroid Build Coastguard Worker     {
1680*89c4ff92SAndroid Build Coastguard Worker         layer = m_Network->AddTransposeConvolution2dLayer(desc,
1681*89c4ff92SAndroid Build Coastguard Worker                                                           filterTensorAndData.first,
1682*89c4ff92SAndroid Build Coastguard Worker                                                           EmptyOptional(),
1683*89c4ff92SAndroid Build Coastguard Worker                                                           layerName.c_str());
1684*89c4ff92SAndroid Build Coastguard Worker     }
1685*89c4ff92SAndroid Build Coastguard Worker 
1686*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1687*89c4ff92SAndroid Build Coastguard Worker 
1688*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0 , { 2, 1 });
1689*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1690*89c4ff92SAndroid Build Coastguard Worker 
1691*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const (filter) tensor
1692*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1693*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[2]});
1694*89c4ff92SAndroid Build Coastguard Worker 
1695*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1696*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1697*89c4ff92SAndroid Build Coastguard Worker }
1698*89c4ff92SAndroid Build Coastguard Worker 
ParseAveragePool2D(size_t subgraphIndex,size_t operatorIndex)1699*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseAveragePool2D(size_t subgraphIndex, size_t operatorIndex)
1700*89c4ff92SAndroid Build Coastguard Worker {
1701*89c4ff92SAndroid Build Coastguard Worker     ParsePool(subgraphIndex, operatorIndex, PoolingAlgorithm::Average);
1702*89c4ff92SAndroid Build Coastguard Worker }
1703*89c4ff92SAndroid Build Coastguard Worker 
ParseBatchMatMul(size_t subgraphIndex,size_t operatorIndex)1704*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseBatchMatMul(size_t subgraphIndex, size_t operatorIndex)
1705*89c4ff92SAndroid Build Coastguard Worker {
1706*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1707*89c4ff92SAndroid Build Coastguard Worker 
1708*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1709*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
1710*89c4ff92SAndroid Build Coastguard Worker 
1711*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1712*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1713*89c4ff92SAndroid Build Coastguard Worker 
1714*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("BatchMatMul:{}:{}", subgraphIndex, operatorIndex);
1715*89c4ff92SAndroid Build Coastguard Worker 
1716*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputXTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1717*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputYTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1718*89c4ff92SAndroid Build Coastguard Worker 
1719*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1720*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsBatchMatMulOptions();
1721*89c4ff92SAndroid Build Coastguard Worker 
1722*89c4ff92SAndroid Build Coastguard Worker     // Adjoint in tensorflow lite performs transpose operation
1723*89c4ff92SAndroid Build Coastguard Worker     BatchMatMulDescriptor descriptor(options->adj_x,
1724*89c4ff92SAndroid Build Coastguard Worker                                      options->adj_y,
1725*89c4ff92SAndroid Build Coastguard Worker                                      false,
1726*89c4ff92SAndroid Build Coastguard Worker                                      false);
1727*89c4ff92SAndroid Build Coastguard Worker                                      // Arbitrary DataLayout
1728*89c4ff92SAndroid Build Coastguard Worker 
1729*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddBatchMatMulLayer(descriptor, layerName.c_str());
1730*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1731*89c4ff92SAndroid Build Coastguard Worker 
1732*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
1733*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1734*89c4ff92SAndroid Build Coastguard Worker 
1735*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1736*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
1737*89c4ff92SAndroid Build Coastguard Worker 
1738*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1739*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1740*89c4ff92SAndroid Build Coastguard Worker }
1741*89c4ff92SAndroid Build Coastguard Worker 
ParseBatchToSpaceND(size_t subgraphIndex,size_t operatorIndex)1742*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseBatchToSpaceND(size_t subgraphIndex, size_t operatorIndex)
1743*89c4ff92SAndroid Build Coastguard Worker {
1744*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1745*89c4ff92SAndroid Build Coastguard Worker 
1746*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1747*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 3);
1748*89c4ff92SAndroid Build Coastguard Worker 
1749*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1750*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1751*89c4ff92SAndroid Build Coastguard Worker 
1752*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo blockShapeTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1753*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr blockShapeBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
1754*89c4ff92SAndroid Build Coastguard Worker 
1755*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo cropsTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
1756*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr cropsBufferPtr = GetBuffer(m_Model, inputs[2]->buffer);
1757*89c4ff92SAndroid Build Coastguard Worker 
1758*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
1759*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(blockShape.data(), blockShapeBufferPtr->data.data(), blockShapeTensorInfo.GetNumBytes());
1760*89c4ff92SAndroid Build Coastguard Worker 
1761*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> cropsVector(cropsTensorInfo.GetNumElements());
1762*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(cropsVector.data(), cropsBufferPtr->data.data(), cropsTensorInfo.GetNumBytes());
1763*89c4ff92SAndroid Build Coastguard Worker 
1764*89c4ff92SAndroid Build Coastguard Worker     size_t step = 2;
1765*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::pair<unsigned int, unsigned int>> crops;
1766*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < cropsTensorInfo.GetNumElements() / step; ++i)
1767*89c4ff92SAndroid Build Coastguard Worker     {
1768*89c4ff92SAndroid Build Coastguard Worker         crops.emplace_back(cropsVector[i * step], cropsVector[i * step + 1]);
1769*89c4ff92SAndroid Build Coastguard Worker     }
1770*89c4ff92SAndroid Build Coastguard Worker 
1771*89c4ff92SAndroid Build Coastguard Worker     armnn::BatchToSpaceNdDescriptor desc;
1772*89c4ff92SAndroid Build Coastguard Worker     desc.m_BlockShape = blockShape;
1773*89c4ff92SAndroid Build Coastguard Worker     desc.m_Crops = crops;
1774*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
1775*89c4ff92SAndroid Build Coastguard Worker 
1776*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("BatchToSpaceND:{}:{}", subgraphIndex, operatorIndex);
1777*89c4ff92SAndroid Build Coastguard Worker 
1778*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1779*89c4ff92SAndroid Build Coastguard Worker 
1780*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddBatchToSpaceNdLayer(desc, layerName.c_str());
1781*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1782*89c4ff92SAndroid Build Coastguard Worker 
1783*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
1784*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
1785*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1786*89c4ff92SAndroid Build Coastguard Worker 
1787*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1788*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1789*89c4ff92SAndroid Build Coastguard Worker 
1790*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1791*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1792*89c4ff92SAndroid Build Coastguard Worker }
1793*89c4ff92SAndroid Build Coastguard Worker 
ParseL2Normalization(size_t subgraphIndex,size_t operatorIndex)1794*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseL2Normalization(size_t subgraphIndex, size_t operatorIndex)
1795*89c4ff92SAndroid Build Coastguard Worker {
1796*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1797*89c4ff92SAndroid Build Coastguard Worker 
1798*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1799*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
1800*89c4ff92SAndroid Build Coastguard Worker 
1801*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1802*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1803*89c4ff92SAndroid Build Coastguard Worker 
1804*89c4ff92SAndroid Build Coastguard Worker     L2NormalizationDescriptor desc;
1805*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
1806*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("L2Normalization:{}:{}", subgraphIndex, operatorIndex);
1807*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddL2NormalizationLayer(desc, layerName.c_str());
1808*89c4ff92SAndroid Build Coastguard Worker 
1809*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1810*89c4ff92SAndroid Build Coastguard Worker 
1811*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
1812*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1813*89c4ff92SAndroid Build Coastguard Worker 
1814*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1815*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1816*89c4ff92SAndroid Build Coastguard Worker 
1817*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1818*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1819*89c4ff92SAndroid Build Coastguard Worker }
1820*89c4ff92SAndroid Build Coastguard Worker 
ParseMaxPool2D(size_t subgraphIndex,size_t operatorIndex)1821*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseMaxPool2D(size_t subgraphIndex, size_t operatorIndex)
1822*89c4ff92SAndroid Build Coastguard Worker {
1823*89c4ff92SAndroid Build Coastguard Worker     ParsePool(subgraphIndex, operatorIndex, PoolingAlgorithm::Max);
1824*89c4ff92SAndroid Build Coastguard Worker }
1825*89c4ff92SAndroid Build Coastguard Worker 
ParseMaximum(size_t subgraphIndex,size_t operatorIndex)1826*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseMaximum(size_t subgraphIndex, size_t operatorIndex)
1827*89c4ff92SAndroid Build Coastguard Worker {
1828*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1829*89c4ff92SAndroid Build Coastguard Worker 
1830*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1831*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
1832*89c4ff92SAndroid Build Coastguard Worker 
1833*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1834*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1835*89c4ff92SAndroid Build Coastguard Worker 
1836*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Maximum:{}:{}", subgraphIndex, operatorIndex);
1837*89c4ff92SAndroid Build Coastguard Worker 
1838*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1839*89c4ff92SAndroid Build Coastguard Worker     TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1840*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, input1TensorInfo, layerName, "Input 0", "Input 1");
1841*89c4ff92SAndroid Build Coastguard Worker 
1842*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Maximum, layerName.c_str());
1843*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1844*89c4ff92SAndroid Build Coastguard Worker 
1845*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
1846*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
1847*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1848*89c4ff92SAndroid Build Coastguard Worker 
1849*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1850*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
1851*89c4ff92SAndroid Build Coastguard Worker 
1852*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1853*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1854*89c4ff92SAndroid Build Coastguard Worker }
1855*89c4ff92SAndroid Build Coastguard Worker 
ParseMinimum(size_t subgraphIndex,size_t operatorIndex)1856*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseMinimum(size_t subgraphIndex, size_t operatorIndex)
1857*89c4ff92SAndroid Build Coastguard Worker {
1858*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1859*89c4ff92SAndroid Build Coastguard Worker 
1860*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1861*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
1862*89c4ff92SAndroid Build Coastguard Worker 
1863*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1864*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1865*89c4ff92SAndroid Build Coastguard Worker 
1866*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Minimum:{}:{}", subgraphIndex, operatorIndex);
1867*89c4ff92SAndroid Build Coastguard Worker 
1868*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1869*89c4ff92SAndroid Build Coastguard Worker     TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1870*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, input1TensorInfo, layerName, "Input 0", "Input 1");
1871*89c4ff92SAndroid Build Coastguard Worker 
1872*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Minimum, layerName.c_str());
1873*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1874*89c4ff92SAndroid Build Coastguard Worker 
1875*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
1876*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
1877*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1878*89c4ff92SAndroid Build Coastguard Worker 
1879*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1880*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
1881*89c4ff92SAndroid Build Coastguard Worker 
1882*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1883*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1884*89c4ff92SAndroid Build Coastguard Worker }
1885*89c4ff92SAndroid Build Coastguard Worker 
ParsePool(size_t subgraphIndex,size_t operatorIndex,PoolingAlgorithm algorithm)1886*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParsePool(size_t subgraphIndex,
1887*89c4ff92SAndroid Build Coastguard Worker                                  size_t operatorIndex,
1888*89c4ff92SAndroid Build Coastguard Worker                                  PoolingAlgorithm algorithm)
1889*89c4ff92SAndroid Build Coastguard Worker {
1890*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1891*89c4ff92SAndroid Build Coastguard Worker 
1892*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
1893*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsPool2DOptions();
1894*89c4ff92SAndroid Build Coastguard Worker 
1895*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
1896*89c4ff92SAndroid Build Coastguard Worker 
1897*89c4ff92SAndroid Build Coastguard Worker     std::string layerName;
1898*89c4ff92SAndroid Build Coastguard Worker 
1899*89c4ff92SAndroid Build Coastguard Worker     switch (algorithm)
1900*89c4ff92SAndroid Build Coastguard Worker     {
1901*89c4ff92SAndroid Build Coastguard Worker         case PoolingAlgorithm::Average:
1902*89c4ff92SAndroid Build Coastguard Worker             layerName =
1903*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("AveragePool2D:{}:{}", subgraphIndex, operatorIndex);
1904*89c4ff92SAndroid Build Coastguard Worker             break;
1905*89c4ff92SAndroid Build Coastguard Worker         case PoolingAlgorithm::Max:
1906*89c4ff92SAndroid Build Coastguard Worker             layerName =
1907*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("MaxPool2D:{}:{}", subgraphIndex, operatorIndex);
1908*89c4ff92SAndroid Build Coastguard Worker             break;
1909*89c4ff92SAndroid Build Coastguard Worker         default:
1910*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "Unsupported Pooling Algorithm");
1911*89c4ff92SAndroid Build Coastguard Worker     }
1912*89c4ff92SAndroid Build Coastguard Worker 
1913*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
1914*89c4ff92SAndroid Build Coastguard Worker 
1915*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolType = algorithm;
1916*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = CHECKED_NON_NEGATIVE(options->stride_w);
1917*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY = CHECKED_NON_NEGATIVE(options->stride_h);
1918*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolWidth = CHECKED_NON_NEGATIVE(options->filter_width);
1919*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolHeight = CHECKED_NON_NEGATIVE(options->filter_height);
1920*89c4ff92SAndroid Build Coastguard Worker     desc.m_PaddingMethod = PaddingMethod::Exclude;
1921*89c4ff92SAndroid Build Coastguard Worker     desc.m_OutputShapeRounding = OutputShapeRounding::Floor;
1922*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
1923*89c4ff92SAndroid Build Coastguard Worker 
1924*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1925*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
1926*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1927*89c4ff92SAndroid Build Coastguard Worker 
1928*89c4ff92SAndroid Build Coastguard Worker     // assuming input is NHWC
1929*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputHeight = inputTensorInfo.GetShape()[1];
1930*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputWidth  = inputTensorInfo.GetShape()[2];
1931*89c4ff92SAndroid Build Coastguard Worker 
1932*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputHeight, desc.m_PoolHeight, desc.m_StrideY, 1u,
1933*89c4ff92SAndroid Build Coastguard Worker                 desc.m_PadTop, desc.m_PadBottom, options->padding);
1934*89c4ff92SAndroid Build Coastguard Worker     CalcPadding(inputWidth, desc.m_PoolWidth, desc.m_StrideX, 1u,
1935*89c4ff92SAndroid Build Coastguard Worker                 desc.m_PadLeft, desc.m_PadRight, options->padding);
1936*89c4ff92SAndroid Build Coastguard Worker 
1937*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1938*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1939*89c4ff92SAndroid Build Coastguard Worker 
1940*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddPooling2dLayer(desc, layerName.c_str());
1941*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
1942*89c4ff92SAndroid Build Coastguard Worker 
1943*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
1944*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
1945*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
1946*89c4ff92SAndroid Build Coastguard Worker 
1947*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
1948*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
1949*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
1950*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
1951*89c4ff92SAndroid Build Coastguard Worker 
1952*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
1953*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
1954*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
1955*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
1956*89c4ff92SAndroid Build Coastguard Worker }
1957*89c4ff92SAndroid Build Coastguard Worker 
ParseSlice(size_t subgraphIndex,size_t operatorIndex)1958*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSlice(size_t subgraphIndex, size_t operatorIndex)
1959*89c4ff92SAndroid Build Coastguard Worker {
1960*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
1961*89c4ff92SAndroid Build Coastguard Worker 
1962*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
1963*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 3);
1964*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
1965*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
1966*89c4ff92SAndroid Build Coastguard Worker 
1967*89c4ff92SAndroid Build Coastguard Worker     SliceDescriptor desc;
1968*89c4ff92SAndroid Build Coastguard Worker 
1969*89c4ff92SAndroid Build Coastguard Worker     // set begin tensor info for slice descriptor
1970*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo beginTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
1971*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr beginBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
1972*89c4ff92SAndroid Build Coastguard Worker 
1973*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> begin(beginTensorInfo.GetNumElements());
1974*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(begin.data(), beginBufferPtr->data.data(), beginTensorInfo.GetNumBytes());
1975*89c4ff92SAndroid Build Coastguard Worker 
1976*89c4ff92SAndroid Build Coastguard Worker     // set size tensor info for slice descriptor
1977*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo sizeTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
1978*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr sizeBufferPtr = GetBuffer(m_Model, inputs[2]->buffer);
1979*89c4ff92SAndroid Build Coastguard Worker 
1980*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> signedSize(sizeTensorInfo.GetNumElements(), 1);
1981*89c4ff92SAndroid Build Coastguard Worker 
1982*89c4ff92SAndroid Build Coastguard Worker     // if size buffer data is not specified, all contents of size vector remain as values of 1
1983*89c4ff92SAndroid Build Coastguard Worker     if (sizeBufferPtr->data.data())
1984*89c4ff92SAndroid Build Coastguard Worker     {
1985*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(signedSize.data(), sizeBufferPtr->data.data(), sizeTensorInfo.GetNumBytes());
1986*89c4ff92SAndroid Build Coastguard Worker     }
1987*89c4ff92SAndroid Build Coastguard Worker 
1988*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> size(sizeTensorInfo.GetNumElements());
1989*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
1990*89c4ff92SAndroid Build Coastguard Worker 
1991*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < signedSize.size(); ++i)
1992*89c4ff92SAndroid Build Coastguard Worker     {
1993*89c4ff92SAndroid Build Coastguard Worker         int signedValue = signedSize[i];
1994*89c4ff92SAndroid Build Coastguard Worker 
1995*89c4ff92SAndroid Build Coastguard Worker         if (signedValue < -1 || signedValue > static_cast<int>(inputTensorInfo.GetShape()[i] - begin[i]))
1996*89c4ff92SAndroid Build Coastguard Worker         {
1997*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(fmt::format("Invalid value for size {} size must be in range "
1998*89c4ff92SAndroid Build Coastguard Worker                                              "[-1, inputDimSize - begin] [-1, {}] inclusive {}",
1999*89c4ff92SAndroid Build Coastguard Worker                                              signedValue,
2000*89c4ff92SAndroid Build Coastguard Worker                                              inputTensorInfo.GetShape()[i] - begin[i],
2001*89c4ff92SAndroid Build Coastguard Worker                                              CHECK_LOCATION().AsString()));
2002*89c4ff92SAndroid Build Coastguard Worker         }
2003*89c4ff92SAndroid Build Coastguard Worker 
2004*89c4ff92SAndroid Build Coastguard Worker         if (signedValue == -1)
2005*89c4ff92SAndroid Build Coastguard Worker         {
2006*89c4ff92SAndroid Build Coastguard Worker             size[i] = inputTensorInfo.GetShape()[i] - begin[i];
2007*89c4ff92SAndroid Build Coastguard Worker         }
2008*89c4ff92SAndroid Build Coastguard Worker         else
2009*89c4ff92SAndroid Build Coastguard Worker         {
2010*89c4ff92SAndroid Build Coastguard Worker             size[i] = static_cast<unsigned int>(signedValue);
2011*89c4ff92SAndroid Build Coastguard Worker         }
2012*89c4ff92SAndroid Build Coastguard Worker     }
2013*89c4ff92SAndroid Build Coastguard Worker 
2014*89c4ff92SAndroid Build Coastguard Worker     desc = SliceDescriptor(begin, size);
2015*89c4ff92SAndroid Build Coastguard Worker 
2016*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Slice:{}:{}", subgraphIndex, operatorIndex);
2017*89c4ff92SAndroid Build Coastguard Worker 
2018*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const layer = m_Network->AddSliceLayer(desc, layerName.c_str());
2019*89c4ff92SAndroid Build Coastguard Worker 
2020*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2021*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
2022*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2023*89c4ff92SAndroid Build Coastguard Worker 
2024*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
2025*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
2026*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2027*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2028*89c4ff92SAndroid Build Coastguard Worker 
2029*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
2030*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2031*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2032*89c4ff92SAndroid Build Coastguard Worker }
2033*89c4ff92SAndroid Build Coastguard Worker 
ParseSoftmax(size_t subgraphIndex,size_t operatorIndex)2034*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSoftmax(size_t subgraphIndex, size_t operatorIndex)
2035*89c4ff92SAndroid Build Coastguard Worker {
2036*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2037*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2038*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsSoftmaxOptions();
2039*89c4ff92SAndroid Build Coastguard Worker 
2040*89c4ff92SAndroid Build Coastguard Worker     SoftmaxDescriptor desc;
2041*89c4ff92SAndroid Build Coastguard Worker     desc.m_Beta = options->beta;
2042*89c4ff92SAndroid Build Coastguard Worker 
2043*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2044*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2045*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2046*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2047*89c4ff92SAndroid Build Coastguard Worker 
2048*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Softmax:{}:{}", subgraphIndex, operatorIndex);
2049*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const layer = m_Network->AddSoftmaxLayer(desc, layerName.c_str());
2050*89c4ff92SAndroid Build Coastguard Worker 
2051*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2052*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2053*89c4ff92SAndroid Build Coastguard Worker 
2054*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
2055*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
2056*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2057*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2058*89c4ff92SAndroid Build Coastguard Worker 
2059*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
2060*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2061*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2062*89c4ff92SAndroid Build Coastguard Worker }
2063*89c4ff92SAndroid Build Coastguard Worker 
ParseLogSoftmax(size_t subgraphIndex,size_t operatorIndex)2064*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLogSoftmax(size_t subgraphIndex, size_t operatorIndex)
2065*89c4ff92SAndroid Build Coastguard Worker {
2066*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2067*89c4ff92SAndroid Build Coastguard Worker 
2068*89c4ff92SAndroid Build Coastguard Worker     LogSoftmaxDescriptor desc;
2069*89c4ff92SAndroid Build Coastguard Worker 
2070*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2071*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2072*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2073*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2074*89c4ff92SAndroid Build Coastguard Worker 
2075*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("LogSoftmax:{}:{}", subgraphIndex, operatorIndex);
2076*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const layer = m_Network->AddLogSoftmaxLayer(desc, layerName.c_str());
2077*89c4ff92SAndroid Build Coastguard Worker 
2078*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2079*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2080*89c4ff92SAndroid Build Coastguard Worker 
2081*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
2082*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
2083*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2084*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2085*89c4ff92SAndroid Build Coastguard Worker 
2086*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
2087*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2088*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2089*89c4ff92SAndroid Build Coastguard Worker }
2090*89c4ff92SAndroid Build Coastguard Worker 
ParseSpaceToBatchND(size_t subgraphIndex,size_t operatorIndex)2091*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSpaceToBatchND(size_t subgraphIndex, size_t operatorIndex)
2092*89c4ff92SAndroid Build Coastguard Worker {
2093*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2094*89c4ff92SAndroid Build Coastguard Worker 
2095*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2096*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 3);
2097*89c4ff92SAndroid Build Coastguard Worker 
2098*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2099*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2100*89c4ff92SAndroid Build Coastguard Worker 
2101*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo blockShapeTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2102*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr blockShapeBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
2103*89c4ff92SAndroid Build Coastguard Worker 
2104*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo padListTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
2105*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr padListBufferPtr = GetBuffer(m_Model, inputs[2]->buffer);
2106*89c4ff92SAndroid Build Coastguard Worker 
2107*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> blockShape(blockShapeTensorInfo.GetNumElements());
2108*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(blockShape.data(), blockShapeBufferPtr->data.data(), blockShapeTensorInfo.GetNumBytes());
2109*89c4ff92SAndroid Build Coastguard Worker 
2110*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> padListVector(padListTensorInfo.GetNumElements());
2111*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(padListVector.data(), padListBufferPtr->data.data(), padListTensorInfo.GetNumBytes());
2112*89c4ff92SAndroid Build Coastguard Worker 
2113*89c4ff92SAndroid Build Coastguard Worker     size_t step = 2;
2114*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::pair<unsigned int, unsigned int>> padList;
2115*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < padListTensorInfo.GetNumElements() / step; ++i)
2116*89c4ff92SAndroid Build Coastguard Worker     {
2117*89c4ff92SAndroid Build Coastguard Worker         padList.emplace_back(padListVector[i * step], padListVector[i * step + 1]);
2118*89c4ff92SAndroid Build Coastguard Worker     }
2119*89c4ff92SAndroid Build Coastguard Worker 
2120*89c4ff92SAndroid Build Coastguard Worker     armnn::SpaceToBatchNdDescriptor desc;
2121*89c4ff92SAndroid Build Coastguard Worker     desc.m_BlockShape = blockShape;
2122*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadList = padList;
2123*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
2124*89c4ff92SAndroid Build Coastguard Worker 
2125*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("SpaceToBatchND:{}:{}", subgraphIndex, operatorIndex);
2126*89c4ff92SAndroid Build Coastguard Worker 
2127*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2128*89c4ff92SAndroid Build Coastguard Worker 
2129*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddSpaceToBatchNdLayer(desc, layerName.c_str());
2130*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2131*89c4ff92SAndroid Build Coastguard Worker 
2132*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2133*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
2134*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2135*89c4ff92SAndroid Build Coastguard Worker 
2136*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2137*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2138*89c4ff92SAndroid Build Coastguard Worker 
2139*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2140*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2141*89c4ff92SAndroid Build Coastguard Worker }
2142*89c4ff92SAndroid Build Coastguard Worker 
ParseSpaceToDepth(size_t subgraphIndex,size_t operatorIndex)2143*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSpaceToDepth(size_t subgraphIndex, size_t operatorIndex)
2144*89c4ff92SAndroid Build Coastguard Worker {
2145*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2146*89c4ff92SAndroid Build Coastguard Worker 
2147*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2148*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2149*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2150*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2151*89c4ff92SAndroid Build Coastguard Worker 
2152*89c4ff92SAndroid Build Coastguard Worker     armnn::SpaceToDepthDescriptor descriptor;
2153*89c4ff92SAndroid Build Coastguard Worker 
2154*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2155*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsSpaceToDepthOptions();
2156*89c4ff92SAndroid Build Coastguard Worker     auto blockSize = options->block_size;
2157*89c4ff92SAndroid Build Coastguard Worker     if (blockSize < 2)
2158*89c4ff92SAndroid Build Coastguard Worker     {
2159*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
2160*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid block size: {} Block size should be >= 2 {}",
2161*89c4ff92SAndroid Build Coastguard Worker                             blockSize,
2162*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
2163*89c4ff92SAndroid Build Coastguard Worker     }
2164*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BlockSize = armnn::numeric_cast<uint32_t>(blockSize);
2165*89c4ff92SAndroid Build Coastguard Worker 
2166*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("SpaceToDepth:{}:{}", subgraphIndex, operatorIndex);
2167*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddSpaceToDepthLayer(descriptor, layerName.c_str());
2168*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2169*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2170*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2171*89c4ff92SAndroid Build Coastguard Worker 
2172*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2173*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2174*89c4ff92SAndroid Build Coastguard Worker 
2175*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2176*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2177*89c4ff92SAndroid Build Coastguard Worker }
2178*89c4ff92SAndroid Build Coastguard Worker 
OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,const armnn::TensorInfo & inputTensorInfo)2179*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo TfLiteParserImpl::OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
2180*89c4ff92SAndroid Build Coastguard Worker                                                          const armnn::TensorInfo& inputTensorInfo)
2181*89c4ff92SAndroid Build Coastguard Worker {
2182*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(squeezeDims.size(), 0, 1, 2, 3, 4);
2183*89c4ff92SAndroid Build Coastguard Worker     static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 };
2184*89c4ff92SAndroid Build Coastguard Worker 
2185*89c4ff92SAndroid Build Coastguard Worker     if (inputTensorInfo.GetNumDimensions() > 4)
2186*89c4ff92SAndroid Build Coastguard Worker     {
2187*89c4ff92SAndroid Build Coastguard Worker         std::stringstream ss;
2188*89c4ff92SAndroid Build Coastguard Worker         ss << "Input tensor has unexpected number of dimensions:" << inputTensorInfo.GetNumDimensions()
2189*89c4ff92SAndroid Build Coastguard Worker            << " shape:" << inputTensorInfo.GetShape() << " "
2190*89c4ff92SAndroid Build Coastguard Worker            << CHECK_LOCATION().AsString();
2191*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(ss.str());
2192*89c4ff92SAndroid Build Coastguard Worker     }
2193*89c4ff92SAndroid Build Coastguard Worker 
2194*89c4ff92SAndroid Build Coastguard Worker     if (squeezeDims.empty())
2195*89c4ff92SAndroid Build Coastguard Worker     {
2196*89c4ff92SAndroid Build Coastguard Worker         squeezeDims.assign(dimensionSequence,
2197*89c4ff92SAndroid Build Coastguard Worker                            dimensionSequence+inputTensorInfo.GetNumDimensions());
2198*89c4ff92SAndroid Build Coastguard Worker     }
2199*89c4ff92SAndroid Build Coastguard Worker 
2200*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t> outputDims;
2201*89c4ff92SAndroid Build Coastguard Worker     for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
2202*89c4ff92SAndroid Build Coastguard Worker     {
2203*89c4ff92SAndroid Build Coastguard Worker         bool skipSqueeze = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end());
2204*89c4ff92SAndroid Build Coastguard Worker         auto currentDimension = inputTensorInfo.GetShape()[i];
2205*89c4ff92SAndroid Build Coastguard Worker         if (skipSqueeze || currentDimension != 1)
2206*89c4ff92SAndroid Build Coastguard Worker         {
2207*89c4ff92SAndroid Build Coastguard Worker             outputDims.push_back(currentDimension);
2208*89c4ff92SAndroid Build Coastguard Worker         }
2209*89c4ff92SAndroid Build Coastguard Worker     }
2210*89c4ff92SAndroid Build Coastguard Worker 
2211*89c4ff92SAndroid Build Coastguard Worker     if (outputDims.size() > 4)
2212*89c4ff92SAndroid Build Coastguard Worker     {
2213*89c4ff92SAndroid Build Coastguard Worker         std::stringstream ss;
2214*89c4ff92SAndroid Build Coastguard Worker         ss << "Output tensor has unexpected number of dimensions:" << inputTensorInfo.GetNumDimensions()
2215*89c4ff92SAndroid Build Coastguard Worker            << " shape:" << inputTensorInfo.GetShape() << " "
2216*89c4ff92SAndroid Build Coastguard Worker            << CHECK_LOCATION().AsString();
2217*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(ss.str());
2218*89c4ff92SAndroid Build Coastguard Worker     }
2219*89c4ff92SAndroid Build Coastguard Worker 
2220*89c4ff92SAndroid Build Coastguard Worker     TensorShape outShape = TensorShape(static_cast<unsigned int>(outputDims.size()),
2221*89c4ff92SAndroid Build Coastguard Worker                                        outputDims.data());
2222*89c4ff92SAndroid Build Coastguard Worker 
2223*89c4ff92SAndroid Build Coastguard Worker     // we need to preserve the tensor type and the quantization data as well
2224*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outTensorInfo = inputTensorInfo;
2225*89c4ff92SAndroid Build Coastguard Worker     outTensorInfo.SetShape(outShape);
2226*89c4ff92SAndroid Build Coastguard Worker 
2227*89c4ff92SAndroid Build Coastguard Worker     return outTensorInfo;
2228*89c4ff92SAndroid Build Coastguard Worker }
2229*89c4ff92SAndroid Build Coastguard Worker 
ParseShape(size_t subgraphIndex,size_t operatorIndex)2230*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseShape(size_t subgraphIndex, size_t operatorIndex)
2231*89c4ff92SAndroid Build Coastguard Worker {
2232*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2233*89c4ff92SAndroid Build Coastguard Worker 
2234*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2235*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2236*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2237*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2238*89c4ff92SAndroid Build Coastguard Worker 
2239*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Shape:{}:{}", subgraphIndex, operatorIndex);
2240*89c4ff92SAndroid Build Coastguard Worker 
2241*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddShapeLayer(layerName.c_str());
2242*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2243*89c4ff92SAndroid Build Coastguard Worker 
2244*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2245*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2246*89c4ff92SAndroid Build Coastguard Worker 
2247*89c4ff92SAndroid Build Coastguard Worker     // Check if output tensor type is Signed32 or Signed64
2248*89c4ff92SAndroid Build Coastguard Worker     if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 &&
2249*89c4ff92SAndroid Build Coastguard Worker         outputTensorInfo.GetDataType() != armnn::DataType::Signed64)
2250*89c4ff92SAndroid Build Coastguard Worker     {
2251*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
2252*89c4ff92SAndroid Build Coastguard Worker             fmt::format(
2253*89c4ff92SAndroid Build Coastguard Worker                 "Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}",
2254*89c4ff92SAndroid Build Coastguard Worker                 CHECK_LOCATION().AsString()));
2255*89c4ff92SAndroid Build Coastguard Worker     }
2256*89c4ff92SAndroid Build Coastguard Worker 
2257*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2258*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2259*89c4ff92SAndroid Build Coastguard Worker 
2260*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2261*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
2262*89c4ff92SAndroid Build Coastguard Worker }
2263*89c4ff92SAndroid Build Coastguard Worker 
ParseSqueeze(size_t subgraphIndex,size_t operatorIndex)2264*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSqueeze(size_t subgraphIndex, size_t operatorIndex)
2265*89c4ff92SAndroid Build Coastguard Worker {
2266*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2267*89c4ff92SAndroid Build Coastguard Worker 
2268*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2269*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2270*89c4ff92SAndroid Build Coastguard Worker 
2271*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2272*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2273*89c4ff92SAndroid Build Coastguard Worker 
2274*89c4ff92SAndroid Build Coastguard Worker     const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2275*89c4ff92SAndroid Build Coastguard Worker     const auto * options = operatorPtr->builtin_options.AsSqueezeOptions();
2276*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Squeeze:{}:{}", subgraphIndex, operatorIndex);
2277*89c4ff92SAndroid Build Coastguard Worker 
2278*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2279*89c4ff92SAndroid Build Coastguard Worker 
2280*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t> squeezeDim;
2281*89c4ff92SAndroid Build Coastguard Worker     // A single negative dim index is interpreted as a negative index in python
2282*89c4ff92SAndroid Build Coastguard Worker     // Meaning the index will be the shape size plus the negative index value
2283*89c4ff92SAndroid Build Coastguard Worker     if (options->squeeze_dims.size() == 1 && options->squeeze_dims[0] < 0)
2284*89c4ff92SAndroid Build Coastguard Worker     {
2285*89c4ff92SAndroid Build Coastguard Worker         int32_t dim = static_cast<int32_t>(inputTensorInfo.GetShape().GetNumDimensions()) + options->squeeze_dims[0];
2286*89c4ff92SAndroid Build Coastguard Worker         squeezeDim.push_back(static_cast<uint32_t>(dim));
2287*89c4ff92SAndroid Build Coastguard Worker     }
2288*89c4ff92SAndroid Build Coastguard Worker     else
2289*89c4ff92SAndroid Build Coastguard Worker     {
2290*89c4ff92SAndroid Build Coastguard Worker         squeezeDim = AsUnsignedVector(options->squeeze_dims);
2291*89c4ff92SAndroid Build Coastguard Worker     }
2292*89c4ff92SAndroid Build Coastguard Worker 
2293*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = TfLiteParserImpl::OutputShapeOfSqueeze(squeezeDim, inputTensorInfo);
2294*89c4ff92SAndroid Build Coastguard Worker 
2295*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
2296*89c4ff92SAndroid Build Coastguard Worker 
2297*89c4ff92SAndroid Build Coastguard Worker     ReshapeDescriptor reshapeDesc;
2298*89c4ff92SAndroid Build Coastguard Worker     reshapeDesc.m_TargetShape = outputTensorInfo.GetShape();
2299*89c4ff92SAndroid Build Coastguard Worker 
2300*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIds = GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex);
2301*89c4ff92SAndroid Build Coastguard Worker     m_TensorInfos[outputTensorIds[0]] = outputTensorInfo;
2302*89c4ff92SAndroid Build Coastguard Worker 
2303*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
2304*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2305*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2306*89c4ff92SAndroid Build Coastguard Worker 
2307*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2308*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2309*89c4ff92SAndroid Build Coastguard Worker 
2310*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2311*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2312*89c4ff92SAndroid Build Coastguard Worker }
2313*89c4ff92SAndroid Build Coastguard Worker 
ParseStridedSlice(size_t subgraphIndex,size_t operatorIndex)2314*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseStridedSlice(size_t subgraphIndex, size_t operatorIndex)
2315*89c4ff92SAndroid Build Coastguard Worker {
2316*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2317*89c4ff92SAndroid Build Coastguard Worker 
2318*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2319*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 4);
2320*89c4ff92SAndroid Build Coastguard Worker 
2321*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2322*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2323*89c4ff92SAndroid Build Coastguard Worker 
2324*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2325*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsStridedSliceOptions();
2326*89c4ff92SAndroid Build Coastguard Worker 
2327*89c4ff92SAndroid Build Coastguard Worker     StridedSliceDescriptor desc;
2328*89c4ff92SAndroid Build Coastguard Worker     desc.m_BeginMask = options->begin_mask;
2329*89c4ff92SAndroid Build Coastguard Worker     desc.m_EllipsisMask = options->ellipsis_mask;
2330*89c4ff92SAndroid Build Coastguard Worker     desc.m_EndMask = options->end_mask;
2331*89c4ff92SAndroid Build Coastguard Worker     desc.m_NewAxisMask = options->new_axis_mask;
2332*89c4ff92SAndroid Build Coastguard Worker     desc.m_ShrinkAxisMask = options->shrink_axis_mask;
2333*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout = armnn::DataLayout::NHWC;
2334*89c4ff92SAndroid Build Coastguard Worker 
2335*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo beginTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2336*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr beginBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
2337*89c4ff92SAndroid Build Coastguard Worker 
2338*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> begin(beginTensorInfo.GetNumElements());
2339*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(begin.data(), beginBufferPtr->data.data(), beginTensorInfo.GetNumBytes());
2340*89c4ff92SAndroid Build Coastguard Worker 
2341*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo endTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
2342*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr endBufferPtr = GetBuffer(m_Model, inputs[2]->buffer);
2343*89c4ff92SAndroid Build Coastguard Worker 
2344*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> end(endTensorInfo.GetNumElements());
2345*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(end.data(), endBufferPtr->data.data(), endTensorInfo.GetNumBytes());
2346*89c4ff92SAndroid Build Coastguard Worker 
2347*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo strideTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 3);
2348*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr strideBufferPtr = GetBuffer(m_Model, inputs[3]->buffer);
2349*89c4ff92SAndroid Build Coastguard Worker 
2350*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> stride(strideTensorInfo.GetNumElements());
2351*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(stride.data(), strideBufferPtr->data.data(), strideTensorInfo.GetNumBytes());
2352*89c4ff92SAndroid Build Coastguard Worker 
2353*89c4ff92SAndroid Build Coastguard Worker     desc.m_Begin = begin;
2354*89c4ff92SAndroid Build Coastguard Worker     desc.m_End = end;
2355*89c4ff92SAndroid Build Coastguard Worker     desc.m_Stride = stride;
2356*89c4ff92SAndroid Build Coastguard Worker 
2357*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("StridedSlice:{}:{}", subgraphIndex, operatorIndex);
2358*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddStridedSliceLayer(desc, layerName.c_str());
2359*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2360*89c4ff92SAndroid Build Coastguard Worker 
2361*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2362*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2363*89c4ff92SAndroid Build Coastguard Worker 
2364*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2365*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2366*89c4ff92SAndroid Build Coastguard Worker 
2367*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2368*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2369*89c4ff92SAndroid Build Coastguard Worker }
2370*89c4ff92SAndroid Build Coastguard Worker 
ParseSub(size_t subgraphIndex,size_t operatorIndex)2371*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSub(size_t subgraphIndex, size_t operatorIndex)
2372*89c4ff92SAndroid Build Coastguard Worker {
2373*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2374*89c4ff92SAndroid Build Coastguard Worker 
2375*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2376*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsSubOptions();
2377*89c4ff92SAndroid Build Coastguard Worker 
2378*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2379*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2380*89c4ff92SAndroid Build Coastguard Worker 
2381*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2382*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2383*89c4ff92SAndroid Build Coastguard Worker 
2384*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2385*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2386*89c4ff92SAndroid Build Coastguard Worker 
2387*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Sub:{}:{}", subgraphIndex, operatorIndex);
2388*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Sub, layerName.c_str());
2389*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2390*89c4ff92SAndroid Build Coastguard Worker 
2391*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
2392*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2393*89c4ff92SAndroid Build Coastguard Worker 
2394*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2395*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
2396*89c4ff92SAndroid Build Coastguard Worker 
2397*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
2398*89c4ff92SAndroid Build Coastguard Worker 
2399*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2400*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2401*89c4ff92SAndroid Build Coastguard Worker }
2402*89c4ff92SAndroid Build Coastguard Worker 
ParseDiv(size_t subgraphIndex,size_t operatorIndex)2403*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseDiv(size_t subgraphIndex, size_t operatorIndex)
2404*89c4ff92SAndroid Build Coastguard Worker {
2405*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2406*89c4ff92SAndroid Build Coastguard Worker 
2407*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2408*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsDivOptions();
2409*89c4ff92SAndroid Build Coastguard Worker 
2410*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2411*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2412*89c4ff92SAndroid Build Coastguard Worker 
2413*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2414*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2415*89c4ff92SAndroid Build Coastguard Worker 
2416*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2417*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2418*89c4ff92SAndroid Build Coastguard Worker 
2419*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Div:{}:{}", subgraphIndex, operatorIndex);
2420*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Div, layerName.c_str());
2421*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2422*89c4ff92SAndroid Build Coastguard Worker 
2423*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
2424*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2425*89c4ff92SAndroid Build Coastguard Worker 
2426*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2427*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
2428*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
2429*89c4ff92SAndroid Build Coastguard Worker 
2430*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2431*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2432*89c4ff92SAndroid Build Coastguard Worker }
2433*89c4ff92SAndroid Build Coastguard Worker 
ParseFloorDiv(size_t subgraphIndex,size_t operatorIndex)2434*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseFloorDiv(size_t subgraphIndex, size_t operatorIndex)
2435*89c4ff92SAndroid Build Coastguard Worker {
2436*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2437*89c4ff92SAndroid Build Coastguard Worker 
2438*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2439*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2440*89c4ff92SAndroid Build Coastguard Worker 
2441*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2442*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2443*89c4ff92SAndroid Build Coastguard Worker 
2444*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2445*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2446*89c4ff92SAndroid Build Coastguard Worker 
2447*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Div:{}:{}", subgraphIndex, operatorIndex);
2448*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Div, layerName.c_str());
2449*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2450*89c4ff92SAndroid Build Coastguard Worker 
2451*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
2452*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2453*89c4ff92SAndroid Build Coastguard Worker 
2454*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2455*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
2456*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedFloorLayer(layer, 0);
2457*89c4ff92SAndroid Build Coastguard Worker 
2458*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2459*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2460*89c4ff92SAndroid Build Coastguard Worker }
2461*89c4ff92SAndroid Build Coastguard Worker 
ParseAdd(size_t subgraphIndex,size_t operatorIndex)2462*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseAdd(size_t subgraphIndex, size_t operatorIndex)
2463*89c4ff92SAndroid Build Coastguard Worker {
2464*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2465*89c4ff92SAndroid Build Coastguard Worker 
2466*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2467*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsAddOptions();
2468*89c4ff92SAndroid Build Coastguard Worker 
2469*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2470*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2471*89c4ff92SAndroid Build Coastguard Worker 
2472*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2473*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2474*89c4ff92SAndroid Build Coastguard Worker 
2475*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2476*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2477*89c4ff92SAndroid Build Coastguard Worker 
2478*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Add:{}:{}", subgraphIndex, operatorIndex);
2479*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Add, layerName.c_str());
2480*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2481*89c4ff92SAndroid Build Coastguard Worker 
2482*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
2483*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2484*89c4ff92SAndroid Build Coastguard Worker 
2485*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2486*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
2487*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
2488*89c4ff92SAndroid Build Coastguard Worker 
2489*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2490*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2491*89c4ff92SAndroid Build Coastguard Worker }
2492*89c4ff92SAndroid Build Coastguard Worker 
ParseMul(size_t subgraphIndex,size_t operatorIndex)2493*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseMul(size_t subgraphIndex, size_t operatorIndex)
2494*89c4ff92SAndroid Build Coastguard Worker {
2495*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2496*89c4ff92SAndroid Build Coastguard Worker 
2497*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2498*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsMulOptions();
2499*89c4ff92SAndroid Build Coastguard Worker 
2500*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2501*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2502*89c4ff92SAndroid Build Coastguard Worker 
2503*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2504*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2505*89c4ff92SAndroid Build Coastguard Worker 
2506*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2507*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2508*89c4ff92SAndroid Build Coastguard Worker 
2509*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Mul:{}:{}", subgraphIndex, operatorIndex);
2510*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseBinaryLayer(BinaryOperation::Mul, layerName.c_str());
2511*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2512*89c4ff92SAndroid Build Coastguard Worker 
2513*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
2514*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2515*89c4ff92SAndroid Build Coastguard Worker 
2516*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2517*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
2518*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
2519*89c4ff92SAndroid Build Coastguard Worker 
2520*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2521*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2522*89c4ff92SAndroid Build Coastguard Worker }
2523*89c4ff92SAndroid Build Coastguard Worker 
ParseMean(size_t subgraphIndex,size_t operatorIndex)2524*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseMean(size_t subgraphIndex, size_t operatorIndex)
2525*89c4ff92SAndroid Build Coastguard Worker {
2526*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2527*89c4ff92SAndroid Build Coastguard Worker 
2528*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2529*89c4ff92SAndroid Build Coastguard Worker 
2530*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2531*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2532*89c4ff92SAndroid Build Coastguard Worker 
2533*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2534*89c4ff92SAndroid Build Coastguard Worker     TensorInfo dimTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2535*89c4ff92SAndroid Build Coastguard Worker 
2536*89c4ff92SAndroid Build Coastguard Worker     armnn::MeanDescriptor desc;
2537*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
2538*89c4ff92SAndroid Build Coastguard Worker     // Get const axis value from model and set it to descriptor.
2539*89c4ff92SAndroid Build Coastguard Worker     if (axisBufferPtr != nullptr)
2540*89c4ff92SAndroid Build Coastguard Worker     {
2541*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> axisData(dimTensorInfo.GetNumElements());
2542*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(axisData.data(), axisBufferPtr->data.data(), dimTensorInfo.GetNumBytes());
2543*89c4ff92SAndroid Build Coastguard Worker 
2544*89c4ff92SAndroid Build Coastguard Worker         // Convert the axis to unsigned int and remove duplicates.
2545*89c4ff92SAndroid Build Coastguard Worker         auto rank = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
2546*89c4ff92SAndroid Build Coastguard Worker         std::set<unsigned int> uniqueAxis;
2547*89c4ff92SAndroid Build Coastguard Worker         std::transform(axisData.begin(),
2548*89c4ff92SAndroid Build Coastguard Worker                        axisData.end(),
2549*89c4ff92SAndroid Build Coastguard Worker                        std::inserter(uniqueAxis, uniqueAxis.begin()),
2550*89c4ff92SAndroid Build Coastguard Worker                        [rank](int i)->unsigned int{
2551*89c4ff92SAndroid Build Coastguard Worker                            return static_cast<uint32_t>(((i + rank) % rank)); });
2552*89c4ff92SAndroid Build Coastguard Worker         desc.m_Axis.assign(uniqueAxis.begin(), uniqueAxis.end());
2553*89c4ff92SAndroid Build Coastguard Worker     }
2554*89c4ff92SAndroid Build Coastguard Worker     else
2555*89c4ff92SAndroid Build Coastguard Worker     {
2556*89c4ff92SAndroid Build Coastguard Worker         for (uint32_t i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
2557*89c4ff92SAndroid Build Coastguard Worker         {
2558*89c4ff92SAndroid Build Coastguard Worker             desc.m_Axis.push_back(i);
2559*89c4ff92SAndroid Build Coastguard Worker         }
2560*89c4ff92SAndroid Build Coastguard Worker     }
2561*89c4ff92SAndroid Build Coastguard Worker 
2562*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0], true);
2563*89c4ff92SAndroid Build Coastguard Worker 
2564*89c4ff92SAndroid Build Coastguard Worker     desc.m_KeepDims = inputTensorInfo.GetNumDimensions() == outputTensorInfo.GetNumDimensions() ? true : false;
2565*89c4ff92SAndroid Build Coastguard Worker 
2566*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Mean:{}:{}", subgraphIndex, operatorIndex);
2567*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddMeanLayer(desc, layerName.c_str());
2568*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2569*89c4ff92SAndroid Build Coastguard Worker 
2570*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2571*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2572*89c4ff92SAndroid Build Coastguard Worker 
2573*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2574*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2575*89c4ff92SAndroid Build Coastguard Worker 
2576*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2577*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2578*89c4ff92SAndroid Build Coastguard Worker }
2579*89c4ff92SAndroid Build Coastguard Worker 
ParsePad(size_t subgraphIndex,size_t operatorIndex)2580*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParsePad(size_t subgraphIndex, size_t operatorIndex)
2581*89c4ff92SAndroid Build Coastguard Worker {
2582*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2583*89c4ff92SAndroid Build Coastguard Worker 
2584*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2585*89c4ff92SAndroid Build Coastguard Worker 
2586*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2587*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2588*89c4ff92SAndroid Build Coastguard Worker 
2589*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2590*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo padTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2591*89c4ff92SAndroid Build Coastguard Worker 
2592*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> padBuffer = GetUIntBuffer(padTensorInfo, m_Model, inputs[1]->buffer);
2593*89c4ff92SAndroid Build Coastguard Worker 
2594*89c4ff92SAndroid Build Coastguard Worker     size_t step = 2;
2595*89c4ff92SAndroid Build Coastguard Worker     armnn::PadDescriptor desc;
2596*89c4ff92SAndroid Build Coastguard Worker     auto opcode = GetOpCode(m_Model, subgraphIndex, operatorIndex);
2597*89c4ff92SAndroid Build Coastguard Worker 
2598*89c4ff92SAndroid Build Coastguard Worker     if (opcode == tflite::BuiltinOperator_PAD)
2599*89c4ff92SAndroid Build Coastguard Worker     {
2600*89c4ff92SAndroid Build Coastguard Worker         CHECK_VALID_SIZE(inputs.size(), 2);
2601*89c4ff92SAndroid Build Coastguard Worker 
2602*89c4ff92SAndroid Build Coastguard Worker         if (inputTensorInfo.IsQuantized())
2603*89c4ff92SAndroid Build Coastguard Worker         {
2604*89c4ff92SAndroid Build Coastguard Worker             desc.m_PadValue = static_cast<float>(inputTensorInfo.GetQuantizationOffset());
2605*89c4ff92SAndroid Build Coastguard Worker         }
2606*89c4ff92SAndroid Build Coastguard Worker     }
2607*89c4ff92SAndroid Build Coastguard Worker     else if (opcode == tflite::BuiltinOperator_PADV2)
2608*89c4ff92SAndroid Build Coastguard Worker     {
2609*89c4ff92SAndroid Build Coastguard Worker         CHECK_VALID_SIZE(inputs.size(), 3);
2610*89c4ff92SAndroid Build Coastguard Worker 
2611*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo padValueTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
2612*89c4ff92SAndroid Build Coastguard Worker 
2613*89c4ff92SAndroid Build Coastguard Worker         if (padValueTensorInfo.GetNumElements() != 1)
2614*89c4ff92SAndroid Build Coastguard Worker         {
2615*89c4ff92SAndroid Build Coastguard Worker             ARMNN_THROW_PARSE_EXCEPTION("Multiple padding values are not supported in PADV2");
2616*89c4ff92SAndroid Build Coastguard Worker         }
2617*89c4ff92SAndroid Build Coastguard Worker         BufferRawPtr padValueBufferPtr = GetBuffer(m_Model, inputs[2]->buffer);
2618*89c4ff92SAndroid Build Coastguard Worker 
2619*89c4ff92SAndroid Build Coastguard Worker         // Get the pad value from the input tensor
2620*89c4ff92SAndroid Build Coastguard Worker         if (padValueBufferPtr->data.size() > 0)
2621*89c4ff92SAndroid Build Coastguard Worker         {
2622*89c4ff92SAndroid Build Coastguard Worker             switch (padValueTensorInfo.GetDataType())
2623*89c4ff92SAndroid Build Coastguard Worker             {
2624*89c4ff92SAndroid Build Coastguard Worker                 case armnn::DataType::Float32:
2625*89c4ff92SAndroid Build Coastguard Worker                 {
2626*89c4ff92SAndroid Build Coastguard Worker                     std::vector<float> padValueBuffer(padValueTensorInfo.GetNumElements());
2627*89c4ff92SAndroid Build Coastguard Worker                     ::memcpy(padValueBuffer.data(), padValueBufferPtr->data.data(), padValueBufferPtr->data.size());
2628*89c4ff92SAndroid Build Coastguard Worker                     desc.m_PadValue = padValueBuffer[0];
2629*89c4ff92SAndroid Build Coastguard Worker                     break;
2630*89c4ff92SAndroid Build Coastguard Worker                 }
2631*89c4ff92SAndroid Build Coastguard Worker                 case armnn::DataType::QAsymmU8:
2632*89c4ff92SAndroid Build Coastguard Worker                 {
2633*89c4ff92SAndroid Build Coastguard Worker                     std::vector<uint8_t> padValueBuffer(padValueTensorInfo.GetNumElements());
2634*89c4ff92SAndroid Build Coastguard Worker                     ::memcpy(padValueBuffer.data(), padValueBufferPtr->data.data(), padValueBufferPtr->data.size());
2635*89c4ff92SAndroid Build Coastguard Worker                     desc.m_PadValue = armnn::Dequantize<uint8_t>(padValueBuffer[0],
2636*89c4ff92SAndroid Build Coastguard Worker                                                                  padValueTensorInfo.GetQuantizationScale(),
2637*89c4ff92SAndroid Build Coastguard Worker                                                                  padValueTensorInfo.GetQuantizationOffset());
2638*89c4ff92SAndroid Build Coastguard Worker                     break;
2639*89c4ff92SAndroid Build Coastguard Worker                 }
2640*89c4ff92SAndroid Build Coastguard Worker                 case armnn::DataType::QAsymmS8:
2641*89c4ff92SAndroid Build Coastguard Worker                 case armnn::DataType::QSymmS8:
2642*89c4ff92SAndroid Build Coastguard Worker                 {
2643*89c4ff92SAndroid Build Coastguard Worker                     std::vector<int8_t> padValueBuffer(padValueTensorInfo.GetNumElements());
2644*89c4ff92SAndroid Build Coastguard Worker                     ::memcpy(padValueBuffer.data(), padValueBufferPtr->data.data(), padValueBufferPtr->data.size());
2645*89c4ff92SAndroid Build Coastguard Worker                     desc.m_PadValue = armnn::Dequantize<int8_t>(padValueBuffer[0],
2646*89c4ff92SAndroid Build Coastguard Worker                                                                 padValueTensorInfo.GetQuantizationScale(),
2647*89c4ff92SAndroid Build Coastguard Worker                                                                 padValueTensorInfo.GetQuantizationOffset());
2648*89c4ff92SAndroid Build Coastguard Worker                     break;
2649*89c4ff92SAndroid Build Coastguard Worker                 }
2650*89c4ff92SAndroid Build Coastguard Worker                 default: ARMNN_THROW_PARSE_EXCEPTION("Unsupported DataType");
2651*89c4ff92SAndroid Build Coastguard Worker             }
2652*89c4ff92SAndroid Build Coastguard Worker         }
2653*89c4ff92SAndroid Build Coastguard Worker         else if (inputTensorInfo.IsQuantized())
2654*89c4ff92SAndroid Build Coastguard Worker         {
2655*89c4ff92SAndroid Build Coastguard Worker             desc.m_PadValue = static_cast<float>(inputTensorInfo.GetQuantizationOffset());
2656*89c4ff92SAndroid Build Coastguard Worker         }
2657*89c4ff92SAndroid Build Coastguard Worker     }
2658*89c4ff92SAndroid Build Coastguard Worker 
2659*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < padTensorInfo.GetNumElements() / step; ++i)
2660*89c4ff92SAndroid Build Coastguard Worker     {
2661*89c4ff92SAndroid Build Coastguard Worker         desc.m_PadList.emplace_back(padBuffer[i * step], padBuffer[i * step + 1]);
2662*89c4ff92SAndroid Build Coastguard Worker     }
2663*89c4ff92SAndroid Build Coastguard Worker 
2664*89c4ff92SAndroid Build Coastguard Worker     auto layerName = (opcode == tflite::BuiltinOperator_PAD) ? fmt::format("Pad:{}:{}", subgraphIndex, operatorIndex)
2665*89c4ff92SAndroid Build Coastguard Worker             : fmt::format("PadV2:{}:{}", subgraphIndex, operatorIndex);
2666*89c4ff92SAndroid Build Coastguard Worker 
2667*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddPadLayer(desc, layerName.c_str());
2668*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2669*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2670*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2671*89c4ff92SAndroid Build Coastguard Worker 
2672*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2673*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2674*89c4ff92SAndroid Build Coastguard Worker 
2675*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2676*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2677*89c4ff92SAndroid Build Coastguard Worker }
2678*89c4ff92SAndroid Build Coastguard Worker 
ParseMirrorPad(size_t subgraphIndex,size_t operatorIndex)2679*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseMirrorPad(size_t subgraphIndex, size_t operatorIndex)
2680*89c4ff92SAndroid Build Coastguard Worker {
2681*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2682*89c4ff92SAndroid Build Coastguard Worker 
2683*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2684*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2685*89c4ff92SAndroid Build Coastguard Worker 
2686*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2687*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2688*89c4ff92SAndroid Build Coastguard Worker 
2689*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2690*89c4ff92SAndroid Build Coastguard Worker 
2691*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo padTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2692*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
2693*89c4ff92SAndroid Build Coastguard Worker 
2694*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> padBuffer(padTensorInfo.GetNumElements());
2695*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(padBuffer.data(), bufferPtr->data.data(), padTensorInfo.GetNumBytes());
2696*89c4ff92SAndroid Build Coastguard Worker 
2697*89c4ff92SAndroid Build Coastguard Worker     size_t step = 2;
2698*89c4ff92SAndroid Build Coastguard Worker     armnn::PadDescriptor desc;
2699*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < padTensorInfo.GetNumElements() / step; ++i)
2700*89c4ff92SAndroid Build Coastguard Worker     {
2701*89c4ff92SAndroid Build Coastguard Worker         desc.m_PadList.emplace_back(padBuffer[i * step], padBuffer[i * step + 1]);
2702*89c4ff92SAndroid Build Coastguard Worker     }
2703*89c4ff92SAndroid Build Coastguard Worker 
2704*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2705*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsMirrorPadOptions();
2706*89c4ff92SAndroid Build Coastguard Worker 
2707*89c4ff92SAndroid Build Coastguard Worker     if (options->mode == tflite::MirrorPadMode_REFLECT)
2708*89c4ff92SAndroid Build Coastguard Worker     {
2709*89c4ff92SAndroid Build Coastguard Worker         desc.m_PaddingMode = PaddingMode::Reflect;
2710*89c4ff92SAndroid Build Coastguard Worker     }
2711*89c4ff92SAndroid Build Coastguard Worker     else if (options->mode == tflite::MirrorPadMode_SYMMETRIC)
2712*89c4ff92SAndroid Build Coastguard Worker     {
2713*89c4ff92SAndroid Build Coastguard Worker         desc.m_PaddingMode = PaddingMode::Symmetric;
2714*89c4ff92SAndroid Build Coastguard Worker     }
2715*89c4ff92SAndroid Build Coastguard Worker     else
2716*89c4ff92SAndroid Build Coastguard Worker     {
2717*89c4ff92SAndroid Build Coastguard Worker         ARMNN_THROW_PARSE_EXCEPTION("PaddingMode must be either REFLECT or SYMMETRIC");
2718*89c4ff92SAndroid Build Coastguard Worker     }
2719*89c4ff92SAndroid Build Coastguard Worker 
2720*89c4ff92SAndroid Build Coastguard Worker     // If padding mode is Reflect then both paddings must be no greater than inputShape(i) - 1.
2721*89c4ff92SAndroid Build Coastguard Worker     // If padding mode is Symmetric then both paddings must be no greater than inputShape(i).
2722*89c4ff92SAndroid Build Coastguard Worker     auto inputShape = inputTensorInfo.GetShape();
2723*89c4ff92SAndroid Build Coastguard Worker     auto padList = desc.m_PadList;
2724*89c4ff92SAndroid Build Coastguard Worker 
2725*89c4ff92SAndroid Build Coastguard Worker     const unsigned int isReflect = static_cast<unsigned int>(desc.m_PaddingMode == PaddingMode::Reflect);
2726*89c4ff92SAndroid Build Coastguard Worker     for(unsigned int i = 0; i < padList.size(); ++i)
2727*89c4ff92SAndroid Build Coastguard Worker     {
2728*89c4ff92SAndroid Build Coastguard Worker         if(padList.at(i).first > (inputShape[i] - isReflect) ||
2729*89c4ff92SAndroid Build Coastguard Worker            padList.at(i).second > (inputShape[i] - isReflect))
2730*89c4ff92SAndroid Build Coastguard Worker         {
2731*89c4ff92SAndroid Build Coastguard Worker             ARMNN_THROW_PARSE_EXCEPTION("Padding values must be less (Reflect) or "
2732*89c4ff92SAndroid Build Coastguard Worker                                         "equal (Symmetric) to the dimension size.");
2733*89c4ff92SAndroid Build Coastguard Worker         }
2734*89c4ff92SAndroid Build Coastguard Worker     }
2735*89c4ff92SAndroid Build Coastguard Worker 
2736*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("MirrorPad:{}:{}", subgraphIndex, operatorIndex);
2737*89c4ff92SAndroid Build Coastguard Worker 
2738*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddPadLayer(desc, layerName.c_str());
2739*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2740*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2741*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2742*89c4ff92SAndroid Build Coastguard Worker 
2743*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2744*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2745*89c4ff92SAndroid Build Coastguard Worker 
2746*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2747*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2748*89c4ff92SAndroid Build Coastguard Worker }
2749*89c4ff92SAndroid Build Coastguard Worker 
ParsePrelu(size_t subgraphIndex,size_t operatorIndex)2750*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParsePrelu(size_t subgraphIndex, size_t operatorIndex)
2751*89c4ff92SAndroid Build Coastguard Worker {
2752*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2753*89c4ff92SAndroid Build Coastguard Worker 
2754*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2755*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
2756*89c4ff92SAndroid Build Coastguard Worker 
2757*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2758*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2759*89c4ff92SAndroid Build Coastguard Worker 
2760*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Prelu:{}:{}", subgraphIndex, operatorIndex);
2761*89c4ff92SAndroid Build Coastguard Worker 
2762*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2763*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo alphaTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
2764*89c4ff92SAndroid Build Coastguard Worker 
2765*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddPreluLayer(layerName.c_str());
2766*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2767*89c4ff92SAndroid Build Coastguard Worker 
2768*89c4ff92SAndroid Build Coastguard Worker 
2769*89c4ff92SAndroid Build Coastguard Worker     if (IsConstTensor(inputs[1]))
2770*89c4ff92SAndroid Build Coastguard Worker     {
2771*89c4ff92SAndroid Build Coastguard Worker         auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2772*89c4ff92SAndroid Build Coastguard Worker         armnn::IInputSlot* slot = &(layer->GetInputSlot(0));
2773*89c4ff92SAndroid Build Coastguard Worker         RegisterConsumerOfTensor(subgraphIndex, inputTensorIndexes[0], slot);
2774*89c4ff92SAndroid Build Coastguard Worker 
2775*89c4ff92SAndroid Build Coastguard Worker         auto alphaTensorAndData = CreateConstTensorNonPermuted(inputs[1], alphaTensorInfo,
2776*89c4ff92SAndroid Build Coastguard Worker                                                                inputTensorInfo.GetDataType());
2777*89c4ff92SAndroid Build Coastguard Worker         std::string constLayerName = fmt::format("Constant:{}", inputs[1]->name);
2778*89c4ff92SAndroid Build Coastguard Worker         IConnectableLayer* constLayer =
2779*89c4ff92SAndroid Build Coastguard Worker                     m_Network->AddConstantLayer(alphaTensorAndData.first, constLayerName.c_str());
2780*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(constLayer != nullptr);
2781*89c4ff92SAndroid Build Coastguard Worker 
2782*89c4ff92SAndroid Build Coastguard Worker         constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
2783*89c4ff92SAndroid Build Coastguard Worker         constLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
2784*89c4ff92SAndroid Build Coastguard Worker         RegisterOutputSlots(subgraphIndex,
2785*89c4ff92SAndroid Build Coastguard Worker                             VIRTUAL_OPERATOR_ID,
2786*89c4ff92SAndroid Build Coastguard Worker                             constLayer,
2787*89c4ff92SAndroid Build Coastguard Worker                             { inputTensorIndexes[1] });
2788*89c4ff92SAndroid Build Coastguard Worker     }
2789*89c4ff92SAndroid Build Coastguard Worker     else
2790*89c4ff92SAndroid Build Coastguard Worker     {
2791*89c4ff92SAndroid Build Coastguard Worker         auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2792*89c4ff92SAndroid Build Coastguard Worker         RegisterInputSlots(subgraphIndex, operatorIndex, layer, inputTensorIndexes);
2793*89c4ff92SAndroid Build Coastguard Worker     }
2794*89c4ff92SAndroid Build Coastguard Worker 
2795*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
2796*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
2797*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2798*89c4ff92SAndroid Build Coastguard Worker 
2799*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2800*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
2801*89c4ff92SAndroid Build Coastguard Worker }
2802*89c4ff92SAndroid Build Coastguard Worker 
ParseQuantize(size_t subgraphIndex,size_t operatorIndex)2803*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseQuantize(size_t subgraphIndex, size_t operatorIndex)
2804*89c4ff92SAndroid Build Coastguard Worker {
2805*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2806*89c4ff92SAndroid Build Coastguard Worker 
2807*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2808*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2809*89c4ff92SAndroid Build Coastguard Worker 
2810*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2811*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2812*89c4ff92SAndroid Build Coastguard Worker 
2813*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Quantize:{}:{}", subgraphIndex, operatorIndex);
2814*89c4ff92SAndroid Build Coastguard Worker 
2815*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddQuantizeLayer(layerName.c_str());
2816*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
2817*89c4ff92SAndroid Build Coastguard Worker 
2818*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2819*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2820*89c4ff92SAndroid Build Coastguard Worker 
2821*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2822*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2823*89c4ff92SAndroid Build Coastguard Worker 
2824*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2825*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
2826*89c4ff92SAndroid Build Coastguard Worker }
2827*89c4ff92SAndroid Build Coastguard Worker 
ParseRelu(size_t subgraphIndex,size_t operatorIndex)2828*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseRelu(size_t subgraphIndex, size_t operatorIndex)
2829*89c4ff92SAndroid Build Coastguard Worker {
2830*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex,operatorIndex, ActivationFunction::ReLu);
2831*89c4ff92SAndroid Build Coastguard Worker }
2832*89c4ff92SAndroid Build Coastguard Worker 
ParseRelu6(size_t subgraphIndex,size_t operatorIndex)2833*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseRelu6(size_t subgraphIndex, size_t operatorIndex)
2834*89c4ff92SAndroid Build Coastguard Worker {
2835*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex,operatorIndex, ActivationFunction::BoundedReLu);
2836*89c4ff92SAndroid Build Coastguard Worker }
2837*89c4ff92SAndroid Build Coastguard Worker 
ParseLeakyRelu(size_t subgraphIndex,size_t operatorIndex)2838*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLeakyRelu(size_t subgraphIndex, size_t operatorIndex)
2839*89c4ff92SAndroid Build Coastguard Worker {
2840*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::LeakyReLu);
2841*89c4ff92SAndroid Build Coastguard Worker }
2842*89c4ff92SAndroid Build Coastguard Worker 
ParseLogistic(size_t subgraphIndex,size_t operatorIndex)2843*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLogistic(size_t subgraphIndex, size_t operatorIndex)
2844*89c4ff92SAndroid Build Coastguard Worker {
2845*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex,operatorIndex,ActivationFunction::Sigmoid);
2846*89c4ff92SAndroid Build Coastguard Worker }
2847*89c4ff92SAndroid Build Coastguard Worker 
ParseTanH(size_t subgraphIndex,size_t operatorIndex)2848*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseTanH(size_t subgraphIndex, size_t operatorIndex)
2849*89c4ff92SAndroid Build Coastguard Worker {
2850*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex,operatorIndex,ActivationFunction::TanH);
2851*89c4ff92SAndroid Build Coastguard Worker }
2852*89c4ff92SAndroid Build Coastguard Worker 
ParseElu(size_t subgraphIndex,size_t operatorIndex)2853*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseElu(size_t subgraphIndex, size_t operatorIndex)
2854*89c4ff92SAndroid Build Coastguard Worker {
2855*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::Elu);
2856*89c4ff92SAndroid Build Coastguard Worker }
2857*89c4ff92SAndroid Build Coastguard Worker 
ParseHardSwish(size_t subgraphIndex,size_t operatorIndex)2858*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseHardSwish(size_t subgraphIndex, size_t operatorIndex)
2859*89c4ff92SAndroid Build Coastguard Worker {
2860*89c4ff92SAndroid Build Coastguard Worker     ParseActivation(subgraphIndex, operatorIndex, ActivationFunction::HardSwish);
2861*89c4ff92SAndroid Build Coastguard Worker }
2862*89c4ff92SAndroid Build Coastguard Worker 
ParseActivation(size_t subgraphIndex,size_t operatorIndex,ActivationFunction activationType)2863*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseActivation(size_t subgraphIndex, size_t operatorIndex, ActivationFunction activationType)
2864*89c4ff92SAndroid Build Coastguard Worker {
2865*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2866*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2867*89c4ff92SAndroid Build Coastguard Worker     IgnoreUnused(operatorPtr);
2868*89c4ff92SAndroid Build Coastguard Worker 
2869*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2870*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
2871*89c4ff92SAndroid Build Coastguard Worker 
2872*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2873*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2874*89c4ff92SAndroid Build Coastguard Worker 
2875*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Activation:");
2876*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor activationDesc;
2877*89c4ff92SAndroid Build Coastguard Worker     activationDesc.m_Function = activationType;
2878*89c4ff92SAndroid Build Coastguard Worker 
2879*89c4ff92SAndroid Build Coastguard Worker     switch (activationType)
2880*89c4ff92SAndroid Build Coastguard Worker     {
2881*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::ReLu:
2882*89c4ff92SAndroid Build Coastguard Worker         {
2883*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("RELU:{}:{}", subgraphIndex, operatorIndex);
2884*89c4ff92SAndroid Build Coastguard Worker             break;
2885*89c4ff92SAndroid Build Coastguard Worker         }
2886*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::BoundedReLu:
2887*89c4ff92SAndroid Build Coastguard Worker         {
2888*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("RELU6:{}:{}", subgraphIndex, operatorIndex);
2889*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 6.0f;
2890*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = 0.0f;
2891*89c4ff92SAndroid Build Coastguard Worker             break;
2892*89c4ff92SAndroid Build Coastguard Worker         }
2893*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Sigmoid:
2894*89c4ff92SAndroid Build Coastguard Worker         {
2895*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("SIGMOID:{}:{}", subgraphIndex, operatorIndex);
2896*89c4ff92SAndroid Build Coastguard Worker             break;
2897*89c4ff92SAndroid Build Coastguard Worker         }
2898*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::TanH:
2899*89c4ff92SAndroid Build Coastguard Worker         {
2900*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("TANH:{}:{}", subgraphIndex, operatorIndex);
2901*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 1.0f;
2902*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = 1.0f;
2903*89c4ff92SAndroid Build Coastguard Worker             break;
2904*89c4ff92SAndroid Build Coastguard Worker         }
2905*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::LeakyReLu:
2906*89c4ff92SAndroid Build Coastguard Worker         {
2907*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("LEAKYRELU:{}:{}", subgraphIndex, operatorIndex);
2908*89c4ff92SAndroid Build Coastguard Worker             const auto* options = operatorPtr->builtin_options.AsLeakyReluOptions();
2909*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = options->alpha;
2910*89c4ff92SAndroid Build Coastguard Worker             break;
2911*89c4ff92SAndroid Build Coastguard Worker         }
2912*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::Elu:
2913*89c4ff92SAndroid Build Coastguard Worker         {
2914*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("ELU:{}:{}", subgraphIndex, operatorIndex);
2915*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 1.0f;
2916*89c4ff92SAndroid Build Coastguard Worker             break;
2917*89c4ff92SAndroid Build Coastguard Worker         }
2918*89c4ff92SAndroid Build Coastguard Worker         case ActivationFunction::HardSwish:
2919*89c4ff92SAndroid Build Coastguard Worker         {
2920*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("HARDSWISH:{}:{}", subgraphIndex, operatorIndex);
2921*89c4ff92SAndroid Build Coastguard Worker             break;
2922*89c4ff92SAndroid Build Coastguard Worker         }
2923*89c4ff92SAndroid Build Coastguard Worker         default:
2924*89c4ff92SAndroid Build Coastguard Worker         {
2925*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
2926*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Unexpected ActivationFunction[{}] when creating layerName {} ",
2927*89c4ff92SAndroid Build Coastguard Worker                             static_cast<int>(activationType), CHECK_LOCATION().AsString()));
2928*89c4ff92SAndroid Build Coastguard Worker         }
2929*89c4ff92SAndroid Build Coastguard Worker     }
2930*89c4ff92SAndroid Build Coastguard Worker 
2931*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* const layer = m_Network->AddActivationLayer(activationDesc, layerName.c_str());
2932*89c4ff92SAndroid Build Coastguard Worker 
2933*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
2934*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
2935*89c4ff92SAndroid Build Coastguard Worker 
2936*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
2937*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
2938*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
2939*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
2940*89c4ff92SAndroid Build Coastguard Worker 
2941*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
2942*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
2943*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
2944*89c4ff92SAndroid Build Coastguard Worker }
/// Computes the TensorInfo of a Reshape output.
///
/// The result is a copy of inputTensorInfo with only its shape replaced by targetDimsIn. At most one entry of
/// targetDimsIn may be -1. That "stretch" dimension is resolved as the input's element count divided by the
/// product of the other target dimensions.
///
/// @param inputTensorInfo Info of the tensor being reshaped.
/// @param targetDimsIn    Requested output dimensions, optionally containing a single -1.
/// @return inputTensorInfo with its shape replaced by the resolved target shape.
/// @throws ParseException if more than one dimension is -1, or if the other target dimensions multiply to 0
///                        while the input tensor has elements (the stretch dimension would need a division by 0).
armnn::TensorInfo TfLiteParserImpl::OutputShapeOfReshape(const armnn::TensorInfo& inputTensorInfo,
                                                         const std::vector<int32_t>& targetDimsIn)
{
    std::vector<unsigned int> outputDims(targetDimsIn.begin(), targetDimsIn.end());
    const auto stretchDim = std::find(targetDimsIn.begin(), targetDimsIn.end(), -1);

    if (stretchDim != targetDimsIn.end())
    {
        if (std::find(std::next(stretchDim), targetDimsIn.end(), -1) != targetDimsIn.end())
        {
            throw ParseException(
                fmt::format("At most one component of shape can be -1 {}", CHECK_LOCATION().AsString()));
        }

        // Seeding the product with -1 cancels the single -1 entry, leaving the product of the known dimensions.
        auto targetNumElements =
            armnn::numeric_cast<unsigned int>(
                std::accumulate(targetDimsIn.begin(), targetDimsIn.end(), -1, std::multiplies<int32_t>()));

        auto stretchIndex = static_cast<size_t>(std::distance(targetDimsIn.begin(), stretchDim));

        if (targetNumElements == 0)
        {
            // A known dimension is 0, so the stretch dimension cannot be derived by division.
            if (inputTensorInfo.GetNumElements() == 0)
            {
                outputDims[stretchIndex] = 0;
            }
            else
            {
                throw ParseException(
                    fmt::format("Input to reshape is a tensor with elements, but the requested shape has 0. {}",
                                CHECK_LOCATION().AsString()));
            }
        }
        else
        {
            outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
        }
    }

    TensorShape outputShape = TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());

    TensorInfo reshapeInfo = inputTensorInfo;
    reshapeInfo.SetShape(outputShape);

    return reshapeInfo;
}
2974*89c4ff92SAndroid Build Coastguard Worker 
ParseReshape(size_t subgraphIndex,size_t operatorIndex)2975*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseReshape(size_t subgraphIndex, size_t operatorIndex)
2976*89c4ff92SAndroid Build Coastguard Worker {
2977*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
2978*89c4ff92SAndroid Build Coastguard Worker 
2979*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
2980*89c4ff92SAndroid Build Coastguard Worker 
2981*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
2982*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
2983*89c4ff92SAndroid Build Coastguard Worker 
2984*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
2985*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsReshapeOptions();
2986*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Reshape:{}:{}", subgraphIndex, operatorIndex);
2987*89c4ff92SAndroid Build Coastguard Worker 
2988*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
2989*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo actualOutputTensorInfo  = ToTensorInfo(outputs[0]);
2990*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, actualOutputTensorInfo, layerName, "Input 0", "Output 0");
2991*89c4ff92SAndroid Build Coastguard Worker 
2992*89c4ff92SAndroid Build Coastguard Worker     // Extracting new shape for the output
2993*89c4ff92SAndroid Build Coastguard Worker     // There are two ways it can be passed
2994*89c4ff92SAndroid Build Coastguard Worker     //  * First is to define the target shape in the operator built-in options
2995*89c4ff92SAndroid Build Coastguard Worker     //  * Second is to pass it as a second input tensor
2996*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> targetShape;
2997*89c4ff92SAndroid Build Coastguard Worker     bool targetShapeFound = false;
2998*89c4ff92SAndroid Build Coastguard Worker     // Check if built-in options were given
2999*89c4ff92SAndroid Build Coastguard Worker     if (options != nullptr)
3000*89c4ff92SAndroid Build Coastguard Worker     {
3001*89c4ff92SAndroid Build Coastguard Worker         // make sure the parameter is given
3002*89c4ff92SAndroid Build Coastguard Worker         if (options->new_shape.empty() == false)
3003*89c4ff92SAndroid Build Coastguard Worker         {
3004*89c4ff92SAndroid Build Coastguard Worker             targetShape = options->new_shape;
3005*89c4ff92SAndroid Build Coastguard Worker             targetShapeFound = true;
3006*89c4ff92SAndroid Build Coastguard Worker         }
3007*89c4ff92SAndroid Build Coastguard Worker     }
3008*89c4ff92SAndroid Build Coastguard Worker 
3009*89c4ff92SAndroid Build Coastguard Worker     // If there is no built-in option given or if the built-in new_shape parameter was empty
3010*89c4ff92SAndroid Build Coastguard Worker     if (!targetShapeFound)
3011*89c4ff92SAndroid Build Coastguard Worker     {
3012*89c4ff92SAndroid Build Coastguard Worker         // Check for a second input tensor
3013*89c4ff92SAndroid Build Coastguard Worker         if (inputs.size() > 1 && inputs[1] != nullptr)
3014*89c4ff92SAndroid Build Coastguard Worker         {
3015*89c4ff92SAndroid Build Coastguard Worker             if (inputs[1]->is_variable)
3016*89c4ff92SAndroid Build Coastguard Worker             {
3017*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_THROW_PARSE_EXCEPTION( "Target shapes defined in non-const input tensors is not supported");
3018*89c4ff92SAndroid Build Coastguard Worker             }
3019*89c4ff92SAndroid Build Coastguard Worker 
3020*89c4ff92SAndroid Build Coastguard Worker             if (inputs[1]->shape.size() != 1)
3021*89c4ff92SAndroid Build Coastguard Worker             {
3022*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not a 1D tensor");
3023*89c4ff92SAndroid Build Coastguard Worker             }
3024*89c4ff92SAndroid Build Coastguard Worker 
3025*89c4ff92SAndroid Build Coastguard Worker             if (inputs[1]->type != tflite::TensorType_INT32)
3026*89c4ff92SAndroid Build Coastguard Worker             {
3027*89c4ff92SAndroid Build Coastguard Worker                 ARMNN_THROW_PARSE_EXCEPTION("Target 'shape' input is not an int32 type");
3028*89c4ff92SAndroid Build Coastguard Worker             }
3029*89c4ff92SAndroid Build Coastguard Worker 
3030*89c4ff92SAndroid Build Coastguard Worker             // Extract target shape from input
3031*89c4ff92SAndroid Build Coastguard Worker             auto bufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
3032*89c4ff92SAndroid Build Coastguard Worker             auto values = reinterpret_cast<const int32_t*>(bufferPtr->data.data());
3033*89c4ff92SAndroid Build Coastguard Worker             if (values)
3034*89c4ff92SAndroid Build Coastguard Worker             {
3035*89c4ff92SAndroid Build Coastguard Worker                 for (int i = 0; i < inputs[1]->shape[0]; ++i)
3036*89c4ff92SAndroid Build Coastguard Worker                 {
3037*89c4ff92SAndroid Build Coastguard Worker                     targetShape.push_back(values[i]);
3038*89c4ff92SAndroid Build Coastguard Worker                 }
3039*89c4ff92SAndroid Build Coastguard Worker             }
3040*89c4ff92SAndroid Build Coastguard Worker             else
3041*89c4ff92SAndroid Build Coastguard Worker             {
3042*89c4ff92SAndroid Build Coastguard Worker                 try
3043*89c4ff92SAndroid Build Coastguard Worker                 {
3044*89c4ff92SAndroid Build Coastguard Worker                     // We attempt to infer during Runtime.
3045*89c4ff92SAndroid Build Coastguard Worker                     TensorShape reshapeShapes = ToTensorInfo(inputs[1]).GetShape();
3046*89c4ff92SAndroid Build Coastguard Worker 
3047*89c4ff92SAndroid Build Coastguard Worker                     if (reshapeShapes[0] == actualOutputTensorInfo.GetNumDimensions())
3048*89c4ff92SAndroid Build Coastguard Worker                     {
3049*89c4ff92SAndroid Build Coastguard Worker                         for (unsigned int i = 0; i < actualOutputTensorInfo.GetShape().GetNumDimensions(); ++i)
3050*89c4ff92SAndroid Build Coastguard Worker                         {
3051*89c4ff92SAndroid Build Coastguard Worker                             targetShape.push_back(actualOutputTensorInfo.GetShape()[i]);
3052*89c4ff92SAndroid Build Coastguard Worker                         }
3053*89c4ff92SAndroid Build Coastguard Worker                     }
3054*89c4ff92SAndroid Build Coastguard Worker                     // The parser only supports shape (batch, -1) or (-1) for non-constant shape input.
3055*89c4ff92SAndroid Build Coastguard Worker                     else if (reshapeShapes[0] > 2)
3056*89c4ff92SAndroid Build Coastguard Worker                     {
3057*89c4ff92SAndroid Build Coastguard Worker                         throw ParseException(fmt::format("Invalid input shape '{}' in Reshape layer '{}' {}. "
3058*89c4ff92SAndroid Build Coastguard Worker                                                          "When inferring during runtime, the parser only supports "
3059*89c4ff92SAndroid Build Coastguard Worker                                                          "shape (batch, -1) or (-1) for target shape input.",
3060*89c4ff92SAndroid Build Coastguard Worker                                                          reshapeShapes[0],
3061*89c4ff92SAndroid Build Coastguard Worker                                                          layerName,
3062*89c4ff92SAndroid Build Coastguard Worker                                                          CHECK_LOCATION().AsString()));
3063*89c4ff92SAndroid Build Coastguard Worker                     }
3064*89c4ff92SAndroid Build Coastguard Worker                     else
3065*89c4ff92SAndroid Build Coastguard Worker                     {
3066*89c4ff92SAndroid Build Coastguard Worker                         const int32_t numInputElements = inputTensorInfo.GetNumElements();
3067*89c4ff92SAndroid Build Coastguard Worker                         const int32_t inputTensorShape = inputTensorInfo.GetShape()[0];
3068*89c4ff92SAndroid Build Coastguard Worker                         if (reshapeShapes[0] == 1)
3069*89c4ff92SAndroid Build Coastguard Worker                         {
3070*89c4ff92SAndroid Build Coastguard Worker                             targetShape = {numInputElements};
3071*89c4ff92SAndroid Build Coastguard Worker                         }
3072*89c4ff92SAndroid Build Coastguard Worker                         else if (reshapeShapes[0] == 2)
3073*89c4ff92SAndroid Build Coastguard Worker                         {
3074*89c4ff92SAndroid Build Coastguard Worker                             targetShape = {inputTensorShape, numInputElements / inputTensorShape};
3075*89c4ff92SAndroid Build Coastguard Worker                         }
3076*89c4ff92SAndroid Build Coastguard Worker                     }
3077*89c4ff92SAndroid Build Coastguard Worker                 }
3078*89c4ff92SAndroid Build Coastguard Worker                 catch (const std::exception& exc)
3079*89c4ff92SAndroid Build Coastguard Worker                 {
3080*89c4ff92SAndroid Build Coastguard Worker                     ARMNN_THROW_PARSE_EXCEPTION("Failed attempt to infer during runtime the target shape input for "
3081*89c4ff92SAndroid Build Coastguard Worker                                                 "Reshape operation. Reshape operator target shape input buffer data "
3082*89c4ff92SAndroid Build Coastguard Worker                                                 "is null. " << exc.what());
3083*89c4ff92SAndroid Build Coastguard Worker                 }
3084*89c4ff92SAndroid Build Coastguard Worker             }
3085*89c4ff92SAndroid Build Coastguard Worker         }
3086*89c4ff92SAndroid Build Coastguard Worker         else
3087*89c4ff92SAndroid Build Coastguard Worker         {
3088*89c4ff92SAndroid Build Coastguard Worker             ARMNN_THROW_PARSE_EXCEPTION("Target shape not defined in reshape parameters or input tensor. "
3089*89c4ff92SAndroid Build Coastguard Worker                                         "At least one method required");
3090*89c4ff92SAndroid Build Coastguard Worker         }
3091*89c4ff92SAndroid Build Coastguard Worker     }
3092*89c4ff92SAndroid Build Coastguard Worker 
3093*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo reshapeOutputTensorInfo =
3094*89c4ff92SAndroid Build Coastguard Worker         TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, targetShape);
3095*89c4ff92SAndroid Build Coastguard Worker 
3096*89c4ff92SAndroid Build Coastguard Worker     // Check for valid input size and that reshape parameters equal output shape
3097*89c4ff92SAndroid Build Coastguard Worker     // The output shape can be provided to us in 2 ways:
3098*89c4ff92SAndroid Build Coastguard Worker     // 1. through the normal 'shape' parameter given by outputs[indx]->shape
3099*89c4ff92SAndroid Build Coastguard Worker     // 2. through additional parameter 'shape_signature' given by outputs[indx]->buffer.
3100*89c4ff92SAndroid Build Coastguard Worker     //    This parameter can sometimes contain -1 value not visible in the 'shape' parameter.
3101*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape& reshapeOutputTensorShape = reshapeOutputTensorInfo.GetShape();
3102*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() > 1 && !CheckShape(reshapeOutputTensorShape, outputs[0]->shape))
3103*89c4ff92SAndroid Build Coastguard Worker     {
3104*89c4ff92SAndroid Build Coastguard Worker         // Attempt to extract output shape from secondary 'shape_signature'
3105*89c4ff92SAndroid Build Coastguard Worker         // parameter and try to CheckShape() with this param.
3106*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> secondaryOutputTargetShape = outputs[0]->shape_signature;
3107*89c4ff92SAndroid Build Coastguard Worker 
3108*89c4ff92SAndroid Build Coastguard Worker         // if outputs[0]->shape_signature contain a -1 value, we need to compute its actual value
3109*89c4ff92SAndroid Build Coastguard Worker         // from reshape input in order to correctly verify reshape parameters equal output shape
3110*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo secondaryReshapeOutputTensorInfo =
3111*89c4ff92SAndroid Build Coastguard Worker             TfLiteParserImpl::OutputShapeOfReshape(inputTensorInfo, secondaryOutputTargetShape);
3112*89c4ff92SAndroid Build Coastguard Worker 
3113*89c4ff92SAndroid Build Coastguard Worker         if (!CheckShape(reshapeOutputTensorShape, secondaryReshapeOutputTensorInfo.GetShape()))
3114*89c4ff92SAndroid Build Coastguard Worker         {
3115*89c4ff92SAndroid Build Coastguard Worker             std::stringstream ss;
3116*89c4ff92SAndroid Build Coastguard Worker             ss << "New shape defined in reshape parameters "
3117*89c4ff92SAndroid Build Coastguard Worker                << reshapeOutputTensorShape
3118*89c4ff92SAndroid Build Coastguard Worker                << " does not equal output shape "
3119*89c4ff92SAndroid Build Coastguard Worker                << actualOutputTensorInfo.GetShape()
3120*89c4ff92SAndroid Build Coastguard Worker                << ": "
3121*89c4ff92SAndroid Build Coastguard Worker                << CHECK_LOCATION().AsString();
3122*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(ss.str());
3123*89c4ff92SAndroid Build Coastguard Worker         }
3124*89c4ff92SAndroid Build Coastguard Worker     }
3125*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIds = GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex);
3126*89c4ff92SAndroid Build Coastguard Worker 
3127*89c4ff92SAndroid Build Coastguard Worker     ReshapeDescriptor reshapeDesc;
3128*89c4ff92SAndroid Build Coastguard Worker     reshapeDesc.m_TargetShape = reshapeOutputTensorInfo.GetShape();
3129*89c4ff92SAndroid Build Coastguard Worker     m_TensorInfos[outputTensorIds[0]] = reshapeOutputTensorInfo;
3130*89c4ff92SAndroid Build Coastguard Worker 
3131*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddReshapeLayer(reshapeDesc, layerName.c_str());
3132*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3133*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(reshapeOutputTensorInfo);
3134*89c4ff92SAndroid Build Coastguard Worker 
3135*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3136*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
3137*89c4ff92SAndroid Build Coastguard Worker 
3138*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3139*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
3140*89c4ff92SAndroid Build Coastguard Worker }
3141*89c4ff92SAndroid Build Coastguard Worker 
ParseResizeBilinear(size_t subgraphIndex,size_t operatorIndex)3142*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseResizeBilinear(size_t subgraphIndex, size_t operatorIndex)
3143*89c4ff92SAndroid Build Coastguard Worker {
3144*89c4ff92SAndroid Build Coastguard Worker     ParseResize(subgraphIndex, operatorIndex, ResizeMethod::Bilinear);
3145*89c4ff92SAndroid Build Coastguard Worker }
3146*89c4ff92SAndroid Build Coastguard Worker 
ParseResizeNearestNeighbor(size_t subgraphIndex,size_t operatorIndex)3147*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseResizeNearestNeighbor(size_t subgraphIndex, size_t operatorIndex)
3148*89c4ff92SAndroid Build Coastguard Worker {
3149*89c4ff92SAndroid Build Coastguard Worker     ParseResize(subgraphIndex, operatorIndex, ResizeMethod::NearestNeighbor);
3150*89c4ff92SAndroid Build Coastguard Worker }
3151*89c4ff92SAndroid Build Coastguard Worker 
ParseResize(size_t subgraphIndex,size_t operatorIndex,ResizeMethod resizeMethod)3152*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseResize(size_t subgraphIndex, size_t operatorIndex, ResizeMethod resizeMethod)
3153*89c4ff92SAndroid Build Coastguard Worker {
3154*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3155*89c4ff92SAndroid Build Coastguard Worker 
3156*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3157*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
3158*89c4ff92SAndroid Build Coastguard Worker 
3159*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3160*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
3161*89c4ff92SAndroid Build Coastguard Worker 
3162*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo sizeTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
3163*89c4ff92SAndroid Build Coastguard Worker 
3164*89c4ff92SAndroid Build Coastguard Worker     // Data for the parsed tensor args (size) must be stored locally.
3165*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> sizeTensorData(sizeTensorInfo.GetNumElements());
3166*89c4ff92SAndroid Build Coastguard Worker 
3167*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr sizeBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
3168*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(sizeTensorData.data(), sizeBufferPtr->data.data(), sizeTensorInfo.GetNumBytes());
3169*89c4ff92SAndroid Build Coastguard Worker 
3170*89c4ff92SAndroid Build Coastguard Worker     ResizeDescriptor desc;
3171*89c4ff92SAndroid Build Coastguard Worker     desc.m_Method       = resizeMethod;
3172*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetHeight = static_cast<uint32_t> (sizeTensorData[0]);
3173*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetWidth  = static_cast<uint32_t> (sizeTensorData[1]);
3174*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout   = armnn::DataLayout::NHWC;
3175*89c4ff92SAndroid Build Coastguard Worker 
3176*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Resize:");
3177*89c4ff92SAndroid Build Coastguard Worker 
3178*89c4ff92SAndroid Build Coastguard Worker     switch (resizeMethod)
3179*89c4ff92SAndroid Build Coastguard Worker     {
3180*89c4ff92SAndroid Build Coastguard Worker         case ResizeMethod::Bilinear:
3181*89c4ff92SAndroid Build Coastguard Worker         {
3182*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("BILINEAR:{}:{}", subgraphIndex, operatorIndex);
3183*89c4ff92SAndroid Build Coastguard Worker 
3184*89c4ff92SAndroid Build Coastguard Worker             const auto & operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3185*89c4ff92SAndroid Build Coastguard Worker             const auto * options     = operatorPtr->builtin_options.AsResizeBilinearOptions();
3186*89c4ff92SAndroid Build Coastguard Worker 
3187*89c4ff92SAndroid Build Coastguard Worker             desc.m_AlignCorners = options->align_corners;
3188*89c4ff92SAndroid Build Coastguard Worker             break;
3189*89c4ff92SAndroid Build Coastguard Worker         }
3190*89c4ff92SAndroid Build Coastguard Worker         case ResizeMethod::NearestNeighbor:
3191*89c4ff92SAndroid Build Coastguard Worker         {
3192*89c4ff92SAndroid Build Coastguard Worker             layerName += fmt::format("NEARESTNEIGHBOR:{}:{}", subgraphIndex, operatorIndex);
3193*89c4ff92SAndroid Build Coastguard Worker             break;
3194*89c4ff92SAndroid Build Coastguard Worker         }
3195*89c4ff92SAndroid Build Coastguard Worker         default:
3196*89c4ff92SAndroid Build Coastguard Worker         {
3197*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
3198*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Unexpected ResizeMethod[{}] when creating layerName {} ",
3199*89c4ff92SAndroid Build Coastguard Worker                             static_cast<int>(resizeMethod), CHECK_LOCATION().AsString()));
3200*89c4ff92SAndroid Build Coastguard Worker         }
3201*89c4ff92SAndroid Build Coastguard Worker     }
3202*89c4ff92SAndroid Build Coastguard Worker 
3203*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3204*89c4ff92SAndroid Build Coastguard Worker 
3205*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddResizeLayer(desc, layerName.c_str());
3206*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3207*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
3208*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, outputTensorInfo, layerName, "Input 0", "Output 0");
3209*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3210*89c4ff92SAndroid Build Coastguard Worker 
3211*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3212*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
3213*89c4ff92SAndroid Build Coastguard Worker 
3214*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3215*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
3216*89c4ff92SAndroid Build Coastguard Worker }
3217*89c4ff92SAndroid Build Coastguard Worker 
ParseConcatenation(size_t subgraphIndex,size_t operatorIndex)3218*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseConcatenation(size_t subgraphIndex, size_t operatorIndex)
3219*89c4ff92SAndroid Build Coastguard Worker {
3220*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3221*89c4ff92SAndroid Build Coastguard Worker 
3222*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3223*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsConcatenationOptions();
3224*89c4ff92SAndroid Build Coastguard Worker 
3225*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
3226*89c4ff92SAndroid Build Coastguard Worker 
3227*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3228*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3229*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIds = GetInputTensorIds(m_Model, subgraphIndex, operatorIndex);
3230*89c4ff92SAndroid Build Coastguard Worker 
3231*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
3232*89c4ff92SAndroid Build Coastguard Worker 
3233*89c4ff92SAndroid Build Coastguard Worker     unsigned int numConcatView = static_cast<unsigned int>(inputs.size());
3234*89c4ff92SAndroid Build Coastguard Worker     uint32_t inputRank = InputTensorInfo(subgraphIndex, operatorIndex, 0).GetNumDimensions();
3235*89c4ff92SAndroid Build Coastguard Worker 
3236*89c4ff92SAndroid Build Coastguard Worker     const unsigned int concatDimInput = static_cast<unsigned int>(
3237*89c4ff92SAndroid Build Coastguard Worker             (static_cast<int>(inputRank) + options->axis) % static_cast<int>(inputRank));
3238*89c4ff92SAndroid Build Coastguard Worker 
3239*89c4ff92SAndroid Build Coastguard Worker     OriginsDescriptor concatDescriptor(static_cast<uint32_t>(numConcatView), inputRank);
3240*89c4ff92SAndroid Build Coastguard Worker     concatDescriptor.SetConcatAxis(concatDimInput);
3241*89c4ff92SAndroid Build Coastguard Worker     unsigned int mergeDimOrigin = 0;
3242*89c4ff92SAndroid Build Coastguard Worker 
3243*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int viewIndex = 0; viewIndex < numConcatView; ++viewIndex)
3244*89c4ff92SAndroid Build Coastguard Worker     {
3245*89c4ff92SAndroid Build Coastguard Worker         TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, viewIndex);
3246*89c4ff92SAndroid Build Coastguard Worker 
3247*89c4ff92SAndroid Build Coastguard Worker         // This set up concatDescriptor view origin
3248*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::ProcessConcatInputTensorInfo(
3249*89c4ff92SAndroid Build Coastguard Worker                 inputTensorInfo, concatDescriptor, concatDimInput, viewIndex, mergeDimOrigin);
3250*89c4ff92SAndroid Build Coastguard Worker     }
3251*89c4ff92SAndroid Build Coastguard Worker 
3252*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Concatenation:{}:{}", subgraphIndex, operatorIndex);
3253*89c4ff92SAndroid Build Coastguard Worker 
3254*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddConcatLayer(concatDescriptor, layerName.c_str());
3255*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3256*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {});
3257*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3258*89c4ff92SAndroid Build Coastguard Worker 
3259*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3260*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
3261*89c4ff92SAndroid Build Coastguard Worker 
3262*89c4ff92SAndroid Build Coastguard Worker     // add fused activation layer
3263*89c4ff92SAndroid Build Coastguard Worker     layer = AddFusedActivationLayer(layer, 0, options->fused_activation_function);
3264*89c4ff92SAndroid Build Coastguard Worker 
3265*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3266*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
3267*89c4ff92SAndroid Build Coastguard Worker }
3268*89c4ff92SAndroid Build Coastguard Worker 
ParseFullyConnected(size_t subgraphIndex,size_t operatorIndex)3269*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseFullyConnected(size_t subgraphIndex, size_t operatorIndex)
3270*89c4ff92SAndroid Build Coastguard Worker {
3271*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3272*89c4ff92SAndroid Build Coastguard Worker 
3273*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorRfr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3274*89c4ff92SAndroid Build Coastguard Worker     const auto options = operatorRfr->builtin_options.AsFullyConnectedOptions();
3275*89c4ff92SAndroid Build Coastguard Worker 
3276*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(options, subgraphIndex, operatorIndex);
3277*89c4ff92SAndroid Build Coastguard Worker 
3278*89c4ff92SAndroid Build Coastguard Worker     FullyConnectedDescriptor desc;
3279*89c4ff92SAndroid Build Coastguard Worker     desc.m_BiasEnabled = false;
3280*89c4ff92SAndroid Build Coastguard Worker     desc.m_TransposeWeightMatrix = true;
3281*89c4ff92SAndroid Build Coastguard Worker 
3282*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3283*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3284*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
3285*89c4ff92SAndroid Build Coastguard Worker 
3286*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo filterTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
3287*89c4ff92SAndroid Build Coastguard Worker 
3288*89c4ff92SAndroid Build Coastguard Worker     // Fully Connected Layer accepts two dimensional weights input
3289*89c4ff92SAndroid Build Coastguard Worker     int32_t weightsDimension = static_cast<int32_t>(filterTensorInfo.GetNumDimensions());
3290*89c4ff92SAndroid Build Coastguard Worker     if (weightsDimension != 2)
3291*89c4ff92SAndroid Build Coastguard Worker     {
3292*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
3293*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Dimension {} for Fully Connected weights is not supported by Armnn. "
3294*89c4ff92SAndroid Build Coastguard Worker                         "Node {}",
3295*89c4ff92SAndroid Build Coastguard Worker                         weightsDimension,
3296*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
3297*89c4ff92SAndroid Build Coastguard Worker     }
3298*89c4ff92SAndroid Build Coastguard Worker 
3299*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = nullptr;
3300*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("FullyConnected:{}:{}", subgraphIndex, operatorIndex);
3301*89c4ff92SAndroid Build Coastguard Worker 
3302*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3303*89c4ff92SAndroid Build Coastguard Worker     // Add the first input tensor to the registration list
3304*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> tensorIndexesToRegister = {inputTensorIndexes[0]};
3305*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3306*89c4ff92SAndroid Build Coastguard Worker 
3307*89c4ff92SAndroid Build Coastguard Worker     desc.m_ConstantWeights = IsConstTensor(inputs[1]);
3308*89c4ff92SAndroid Build Coastguard Worker 
3309*89c4ff92SAndroid Build Coastguard Worker     // Add the weights input to the registration list, constant layers will be added by SetupConstantLayers if constant.
3310*89c4ff92SAndroid Build Coastguard Worker     tensorIndexesToRegister.emplace_back(inputTensorIndexes[1]);
3311*89c4ff92SAndroid Build Coastguard Worker 
3312*89c4ff92SAndroid Build Coastguard Worker     if (ShouldConstantTensorBeConverted(inputs[1], inputTensorInfo.GetDataType(), filterTensorInfo.GetDataType()))
3313*89c4ff92SAndroid Build Coastguard Worker     {
3314*89c4ff92SAndroid Build Coastguard Worker         m_ConstantsToDequantize.emplace_back(inputs[1]->buffer);
3315*89c4ff92SAndroid Build Coastguard Worker     }
3316*89c4ff92SAndroid Build Coastguard Worker 
3317*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() == 3)
3318*89c4ff92SAndroid Build Coastguard Worker     {
3319*89c4ff92SAndroid Build Coastguard Worker         desc.m_BiasEnabled = true;
3320*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo biasTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
3321*89c4ff92SAndroid Build Coastguard Worker 
3322*89c4ff92SAndroid Build Coastguard Worker         // Add the biases input to the registration list, constant layer will be added by SetupConstantLayers.
3323*89c4ff92SAndroid Build Coastguard Worker         tensorIndexesToRegister.emplace_back(inputTensorIndexes[2]);
3324*89c4ff92SAndroid Build Coastguard Worker 
3325*89c4ff92SAndroid Build Coastguard Worker         if (ShouldConstantTensorBeConverted(inputs[2], inputTensorInfo.GetDataType(), biasTensorInfo.GetDataType()))
3326*89c4ff92SAndroid Build Coastguard Worker         {
3327*89c4ff92SAndroid Build Coastguard Worker             m_ConstantsToDequantize.emplace_back(inputs[2]->buffer);
3328*89c4ff92SAndroid Build Coastguard Worker         }
3329*89c4ff92SAndroid Build Coastguard Worker     }
3330*89c4ff92SAndroid Build Coastguard Worker 
3331*89c4ff92SAndroid Build Coastguard Worker     // Filters and biases are always passed to fully connected as inputs
3332*89c4ff92SAndroid Build Coastguard Worker     layer = m_Network->AddFullyConnectedLayer(desc, layerName.c_str());
3333*89c4ff92SAndroid Build Coastguard Worker 
3334*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3335*89c4ff92SAndroid Build Coastguard Worker 
3336*89c4ff92SAndroid Build Coastguard Worker     unsigned int startingSlotIndex = 0;
3337*89c4ff92SAndroid Build Coastguard Worker     if (inputTensorInfo.GetNumDimensions() > 2)
3338*89c4ff92SAndroid Build Coastguard Worker     {
3339*89c4ff92SAndroid Build Coastguard Worker         // Add reshape to flatten to 2D [batch_size, input_size],
3340*89c4ff92SAndroid Build Coastguard Worker         // where "input_size" corresponds to the number of inputs to the layer,
3341*89c4ff92SAndroid Build Coastguard Worker         // matching the second dimension of weights,
3342*89c4ff92SAndroid Build Coastguard Worker         // and "batch_size" is calculated by dividing the number of elements by "input_size".
3343*89c4ff92SAndroid Build Coastguard Worker         std::vector<unsigned int> reshapedDimensions(2);
3344*89c4ff92SAndroid Build Coastguard Worker         reshapedDimensions[1] = filterTensorInfo.GetShape()[1];
3345*89c4ff92SAndroid Build Coastguard Worker         reshapedDimensions[0] = inputTensorInfo.GetNumElements() / reshapedDimensions[1];
3346*89c4ff92SAndroid Build Coastguard Worker 
3347*89c4ff92SAndroid Build Coastguard Worker         if (inputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
3348*89c4ff92SAndroid Build Coastguard Worker         {
3349*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
3350*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("Failed to deduce input tensor shape from filter size {} {}",
3351*89c4ff92SAndroid Build Coastguard Worker                                 reshapedDimensions[1],
3352*89c4ff92SAndroid Build Coastguard Worker                                 CHECK_LOCATION().AsString()));
3353*89c4ff92SAndroid Build Coastguard Worker         }
3354*89c4ff92SAndroid Build Coastguard Worker 
3355*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo reshapedTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3356*89c4ff92SAndroid Build Coastguard Worker         reshapedTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
3357*89c4ff92SAndroid Build Coastguard Worker         inputTensorInfo = reshapedTensorInfo;
3358*89c4ff92SAndroid Build Coastguard Worker 
3359*89c4ff92SAndroid Build Coastguard Worker         std::string reshapeLayerName = fmt::format("Reshape_for:{}", layer->GetName());
3360*89c4ff92SAndroid Build Coastguard Worker         armnn::ReshapeDescriptor reshapeDescriptor;
3361*89c4ff92SAndroid Build Coastguard Worker         reshapeDescriptor.m_TargetShape = reshapedTensorInfo.GetShape();
3362*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(reshapeDescriptor,
3363*89c4ff92SAndroid Build Coastguard Worker                                                                             reshapeLayerName.c_str());
3364*89c4ff92SAndroid Build Coastguard Worker 
3365*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapedTensorInfo);
3366*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(0));
3367*89c4ff92SAndroid Build Coastguard Worker 
3368*89c4ff92SAndroid Build Coastguard Worker         RegisterInputSlots(subgraphIndex, operatorIndex, reshapeLayer, {inputTensorIndexes[0]});
3369*89c4ff92SAndroid Build Coastguard Worker         // Fc layer connects to the reshape layer, so we skip the first input slot when registering fc's input slots
3370*89c4ff92SAndroid Build Coastguard Worker         tensorIndexesToRegister.erase(tensorIndexesToRegister.begin());
3371*89c4ff92SAndroid Build Coastguard Worker         startingSlotIndex = 1;
3372*89c4ff92SAndroid Build Coastguard Worker     }
3373*89c4ff92SAndroid Build Coastguard Worker 
3374*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, tensorIndexesToRegister, startingSlotIndex);
3375*89c4ff92SAndroid Build Coastguard Worker 
3376*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromShapes(subgraphIndex, operatorIndex, layer, 0,
3377*89c4ff92SAndroid Build Coastguard Worker                                                                     { inputTensorInfo.GetShape(),
3378*89c4ff92SAndroid Build Coastguard Worker                                                                       filterTensorInfo.GetShape() });
3379*89c4ff92SAndroid Build Coastguard Worker 
3380*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3381*89c4ff92SAndroid Build Coastguard Worker 
3382*89c4ff92SAndroid Build Coastguard Worker     if (outputTensorInfo.GetNumDimensions() > 2)
3383*89c4ff92SAndroid Build Coastguard Worker     {
3384*89c4ff92SAndroid Build Coastguard Worker         // Calculate reshape to flatten to 2D [batch_size, input_size]
3385*89c4ff92SAndroid Build Coastguard Worker         std::vector<unsigned int> reshapedDimensions(2);
3386*89c4ff92SAndroid Build Coastguard Worker         reshapedDimensions[1] = filterTensorInfo.GetShape()[0];
3387*89c4ff92SAndroid Build Coastguard Worker         reshapedDimensions[0] = outputTensorInfo.GetNumElements() / reshapedDimensions[1];
3388*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo reshapedOutputTensorInfo = outputTensorInfo;
3389*89c4ff92SAndroid Build Coastguard Worker         if (outputTensorInfo.GetNumElements() % reshapedDimensions[1] != 0)
3390*89c4ff92SAndroid Build Coastguard Worker         {
3391*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
3392*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("Failed to deduce output tensor shape from filter size {} {}",
3393*89c4ff92SAndroid Build Coastguard Worker                                 reshapedDimensions[1],
3394*89c4ff92SAndroid Build Coastguard Worker                                 CHECK_LOCATION().AsString()));
3395*89c4ff92SAndroid Build Coastguard Worker         }
3396*89c4ff92SAndroid Build Coastguard Worker         reshapedOutputTensorInfo.SetShape(armnn::TensorShape{ 2, reshapedDimensions.data() });
3397*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(0).SetTensorInfo(reshapedOutputTensorInfo);
3398*89c4ff92SAndroid Build Coastguard Worker 
3399*89c4ff92SAndroid Build Coastguard Worker         std::string reshapeLayerName = fmt::format("ExpandDims:{}:{}", subgraphIndex, operatorIndex);
3400*89c4ff92SAndroid Build Coastguard Worker         layer = AddReshapeLayer(layer, 0, reshapeLayerName, outputTensorInfo);
3401*89c4ff92SAndroid Build Coastguard Worker     }
3402*89c4ff92SAndroid Build Coastguard Worker 
3403*89c4ff92SAndroid Build Coastguard Worker     // we need to add the activation layer and fortunately we don't need to care about the data layout
3404*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* fusedActivationLayer = AddFusedActivationLayer(layer, 0,
3405*89c4ff92SAndroid Build Coastguard Worker                                                                              options->fused_activation_function);
3406*89c4ff92SAndroid Build Coastguard Worker 
3407*89c4ff92SAndroid Build Coastguard Worker     // register the output connection slots for the layer, connections are made after all layers have been created
3408*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3409*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, fusedActivationLayer, {outputTensorIndexes[0]});
3410*89c4ff92SAndroid Build Coastguard Worker 
3411*89c4ff92SAndroid Build Coastguard Worker     m_TensorInfos[outputTensorIndexes[0]] = layer->GetOutputSlot(0).GetTensorInfo();
3412*89c4ff92SAndroid Build Coastguard Worker }
3413*89c4ff92SAndroid Build Coastguard Worker 
ParseDetectionPostProcess(size_t subgraphIndex,size_t operatorIndex)3414*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseDetectionPostProcess(size_t subgraphIndex, size_t operatorIndex)
3415*89c4ff92SAndroid Build Coastguard Worker {
3416*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3417*89c4ff92SAndroid Build Coastguard Worker 
3418*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3419*89c4ff92SAndroid Build Coastguard Worker 
3420*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3421*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3422*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 4);
3423*89c4ff92SAndroid Build Coastguard Worker 
3424*89c4ff92SAndroid Build Coastguard Worker     // Obtain custom options from flexbuffers
3425*89c4ff92SAndroid Build Coastguard Worker     auto custom_options = operatorPtr->custom_options;
3426*89c4ff92SAndroid Build Coastguard Worker     const flexbuffers::Map& m = flexbuffers::GetRoot(custom_options.data(), custom_options.size()).AsMap();
3427*89c4ff92SAndroid Build Coastguard Worker 
3428*89c4ff92SAndroid Build Coastguard Worker     // Obtain descriptor information from tf lite
3429*89c4ff92SAndroid Build Coastguard Worker     DetectionPostProcessDescriptor desc;
3430*89c4ff92SAndroid Build Coastguard Worker     desc.m_MaxDetections           = m["max_detections"].AsUInt32();
3431*89c4ff92SAndroid Build Coastguard Worker     desc.m_MaxClassesPerDetection  = m["max_classes_per_detection"].AsUInt32();
3432*89c4ff92SAndroid Build Coastguard Worker     desc.m_NmsScoreThreshold       = m["nms_score_threshold"].AsFloat();
3433*89c4ff92SAndroid Build Coastguard Worker     desc.m_NmsIouThreshold         = m["nms_iou_threshold"].AsFloat();
3434*89c4ff92SAndroid Build Coastguard Worker     desc.m_NumClasses              = m["num_classes"].AsUInt32();
3435*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleH                  = m["h_scale"].AsFloat();
3436*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleW                  = m["w_scale"].AsFloat();
3437*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleX                  = m["x_scale"].AsFloat();
3438*89c4ff92SAndroid Build Coastguard Worker     desc.m_ScaleY                  = m["y_scale"].AsFloat();
3439*89c4ff92SAndroid Build Coastguard Worker 
3440*89c4ff92SAndroid Build Coastguard Worker     if (!(m["use_regular_nms"].IsNull()))
3441*89c4ff92SAndroid Build Coastguard Worker     {
3442*89c4ff92SAndroid Build Coastguard Worker         desc.m_UseRegularNms       = m["use_regular_nms"].AsBool();
3443*89c4ff92SAndroid Build Coastguard Worker     }
3444*89c4ff92SAndroid Build Coastguard Worker     if (!(m["detections_per_class"].IsNull()))
3445*89c4ff92SAndroid Build Coastguard Worker     {
3446*89c4ff92SAndroid Build Coastguard Worker         desc.m_DetectionsPerClass  = m["detections_per_class"].AsUInt32();
3447*89c4ff92SAndroid Build Coastguard Worker     }
3448*89c4ff92SAndroid Build Coastguard Worker 
3449*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_NmsIouThreshold <= 0.0f || desc.m_NmsIouThreshold > 1.0f)
3450*89c4ff92SAndroid Build Coastguard Worker     {
3451*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException("DetectionPostProcessTFLiteParser: Intersection over union threshold "
3452*89c4ff92SAndroid Build Coastguard Worker                                        "must be positive and less than or equal to 1.");
3453*89c4ff92SAndroid Build Coastguard Worker     }
3454*89c4ff92SAndroid Build Coastguard Worker 
3455*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo anchorTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 2);
3456*89c4ff92SAndroid Build Coastguard Worker     auto anchorTensorAndData = CreateConstTensorNonPermuted(inputs[2], anchorTensorInfo);
3457*89c4ff92SAndroid Build Coastguard Worker 
3458*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("DetectionPostProcess:{}:{}", subgraphIndex, operatorIndex);
3459*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddDetectionPostProcessLayer(desc, anchorTensorAndData,
3460*89c4ff92SAndroid Build Coastguard Worker                                                                        layerName.c_str());
3461*89c4ff92SAndroid Build Coastguard Worker 
3462*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3463*89c4ff92SAndroid Build Coastguard Worker 
3464*89c4ff92SAndroid Build Coastguard Worker     // The model does not specify the output shapes.
3465*89c4ff92SAndroid Build Coastguard Worker     // The output shapes are calculated from the max_detection and max_classes_per_detection.
3466*89c4ff92SAndroid Build Coastguard Worker     unsigned int numDetectedBox = desc.m_MaxDetections * desc.m_MaxClassesPerDetection;
3467*89c4ff92SAndroid Build Coastguard Worker     m_OverriddenOutputShapes.push_back({ 1, numDetectedBox, 4 });
3468*89c4ff92SAndroid Build Coastguard Worker     m_OverriddenOutputShapes.push_back({ 1, numDetectedBox });
3469*89c4ff92SAndroid Build Coastguard Worker     m_OverriddenOutputShapes.push_back({ 1, numDetectedBox });
3470*89c4ff92SAndroid Build Coastguard Worker     m_OverriddenOutputShapes.push_back({ 1 });
3471*89c4ff92SAndroid Build Coastguard Worker 
3472*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0 ; i < outputs.size() ; ++i)
3473*89c4ff92SAndroid Build Coastguard Worker     {
3474*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo detectionBoxOutputTensorInfo = ToTensorInfo(outputs[i], m_OverriddenOutputShapes[i]);
3475*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(i).SetTensorInfo(detectionBoxOutputTensorInfo);
3476*89c4ff92SAndroid Build Coastguard Worker     }
3477*89c4ff92SAndroid Build Coastguard Worker 
3478*89c4ff92SAndroid Build Coastguard Worker     // Register the input connection slots for the layer, connections are made after all layers have been created
3479*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
3480*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3481*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
3482*89c4ff92SAndroid Build Coastguard Worker 
3483*89c4ff92SAndroid Build Coastguard Worker     // Register the output connection slots for the layer, connections are made after all layers have been created
3484*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3485*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0],
3486*89c4ff92SAndroid Build Coastguard Worker                                                               outputTensorIndexes[1],
3487*89c4ff92SAndroid Build Coastguard Worker                                                               outputTensorIndexes[2],
3488*89c4ff92SAndroid Build Coastguard Worker                                                               outputTensorIndexes[3]});
3489*89c4ff92SAndroid Build Coastguard Worker }
3490*89c4ff92SAndroid Build Coastguard Worker 
3491*89c4ff92SAndroid Build Coastguard Worker /// The TfLite Pack operator is equivalent to the ArmNN Stack operator
ParsePack(size_t subgraphIndex,size_t operatorIndex)3492*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParsePack(size_t subgraphIndex, size_t operatorIndex)
3493*89c4ff92SAndroid Build Coastguard Worker {
3494*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3495*89c4ff92SAndroid Build Coastguard Worker 
3496*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3497*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3498*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
3499*89c4ff92SAndroid Build Coastguard Worker 
3500*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() < 1)
3501*89c4ff92SAndroid Build Coastguard Worker     {
3502*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("Pack must have at least one input.");
3503*89c4ff92SAndroid Build Coastguard Worker     }
3504*89c4ff92SAndroid Build Coastguard Worker 
3505*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3506*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsPackOptions();
3507*89c4ff92SAndroid Build Coastguard Worker 
3508*89c4ff92SAndroid Build Coastguard Worker     StackDescriptor desc;
3509*89c4ff92SAndroid Build Coastguard Worker     desc.m_Axis = static_cast<uint32_t>(options->axis);
3510*89c4ff92SAndroid Build Coastguard Worker     desc.m_NumInputs = static_cast<uint32_t>(inputs.size());
3511*89c4ff92SAndroid Build Coastguard Worker 
3512*89c4ff92SAndroid Build Coastguard Worker     // Use the tensor shape of the first input as the "correct" input shape in the descriptor
3513*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3514*89c4ff92SAndroid Build Coastguard Worker     desc.m_InputShape = inputTensorInfo.GetShape();
3515*89c4ff92SAndroid Build Coastguard Worker 
3516*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Pack:{}:{}", subgraphIndex, operatorIndex);
3517*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddStackLayer(desc, layerName.c_str());
3518*89c4ff92SAndroid Build Coastguard Worker 
3519*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3520*89c4ff92SAndroid Build Coastguard Worker 
3521*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {});
3522*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3523*89c4ff92SAndroid Build Coastguard Worker 
3524*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3525*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes});
3526*89c4ff92SAndroid Build Coastguard Worker 
3527*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3528*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
3529*89c4ff92SAndroid Build Coastguard Worker }
3530*89c4ff92SAndroid Build Coastguard Worker 
ParseUnidirectionalSequenceLSTM(size_t subgraphIndex,size_t operatorIndex)3531*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseUnidirectionalSequenceLSTM(size_t subgraphIndex, size_t operatorIndex)
3532*89c4ff92SAndroid Build Coastguard Worker {
3533*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3534*89c4ff92SAndroid Build Coastguard Worker 
3535*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3536*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3537*89c4ff92SAndroid Build Coastguard Worker 
3538*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() < 2)
3539*89c4ff92SAndroid Build Coastguard Worker     {
3540*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("UnidirectionalSequenceLSTM must have at least 2 input.");
3541*89c4ff92SAndroid Build Coastguard Worker     }
3542*89c4ff92SAndroid Build Coastguard Worker 
3543*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3544*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = m_Model->subgraphs[subgraphIndex];
3545*89c4ff92SAndroid Build Coastguard Worker     const auto nodeParams = operatorPtr->builtin_options.AsUnidirectionalSequenceLSTMOptions();
3546*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUPPORTED_FUSED_ACTIVATION(nodeParams, subgraphIndex, operatorIndex);
3547*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3548*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorInfo = ToTensorInfo(outputs[0]);
3549*89c4ff92SAndroid Build Coastguard Worker 
3550*89c4ff92SAndroid Build Coastguard Worker     // Set the params structure for the AddUnidirectionalSequenceLstmLayer call
3551*89c4ff92SAndroid Build Coastguard Worker     // Please refer to each operand at
3552*89c4ff92SAndroid Build Coastguard Worker     // https://www.tensorflow.org/mlir/tfl_ops#tflunidirectional_sequence_lstm_tflunidirectionalsequencelstmop
3553*89c4ff92SAndroid Build Coastguard Worker     armnn::LstmInputParams params;
3554*89c4ff92SAndroid Build Coastguard Worker 
3555*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[1]))
3556*89c4ff92SAndroid Build Coastguard Worker     {
3557*89c4ff92SAndroid Build Coastguard Worker         params.m_InputToInputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[1]].get(),
3558*89c4ff92SAndroid Build Coastguard Worker                                                             inputTensorInfo).first;
3559*89c4ff92SAndroid Build Coastguard Worker     }
3560*89c4ff92SAndroid Build Coastguard Worker 
3561*89c4ff92SAndroid Build Coastguard Worker     params.m_InputToForgetWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[2]].get(),
3562*89c4ff92SAndroid Build Coastguard Worker                                                          inputTensorInfo).first;
3563*89c4ff92SAndroid Build Coastguard Worker     params.m_InputToCellWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[3]].get(),
3564*89c4ff92SAndroid Build Coastguard Worker                                                        inputTensorInfo).first;
3565*89c4ff92SAndroid Build Coastguard Worker     params.m_InputToOutputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[4]].get(),
3566*89c4ff92SAndroid Build Coastguard Worker                                                          inputTensorInfo).first;
3567*89c4ff92SAndroid Build Coastguard Worker 
3568*89c4ff92SAndroid Build Coastguard Worker     // Recurrent weight tensors of size {n_cell, n_output}
3569*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[5]))
3570*89c4ff92SAndroid Build Coastguard Worker     {
3571*89c4ff92SAndroid Build Coastguard Worker         params.m_RecurrentToInputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[5]].get(),
3572*89c4ff92SAndroid Build Coastguard Worker                                                                 inputTensorInfo).first;
3573*89c4ff92SAndroid Build Coastguard Worker     }
3574*89c4ff92SAndroid Build Coastguard Worker 
3575*89c4ff92SAndroid Build Coastguard Worker     params.m_RecurrentToForgetWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[6]].get(),
3576*89c4ff92SAndroid Build Coastguard Worker                                                              inputTensorInfo).first;
3577*89c4ff92SAndroid Build Coastguard Worker     params.m_RecurrentToCellWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[7]].get(),
3578*89c4ff92SAndroid Build Coastguard Worker                                                            inputTensorInfo).first;
3579*89c4ff92SAndroid Build Coastguard Worker     params.m_RecurrentToOutputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[8]].get(),
3580*89c4ff92SAndroid Build Coastguard Worker                                                              inputTensorInfo).first;
3581*89c4ff92SAndroid Build Coastguard Worker 
3582*89c4ff92SAndroid Build Coastguard Worker     // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
3583*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[9]))
3584*89c4ff92SAndroid Build Coastguard Worker     {
3585*89c4ff92SAndroid Build Coastguard Worker         params.m_CellToInputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[9]].get(),
3586*89c4ff92SAndroid Build Coastguard Worker                                                            inputTensorInfo).first;
3587*89c4ff92SAndroid Build Coastguard Worker     }
3588*89c4ff92SAndroid Build Coastguard Worker 
3589*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[10]))
3590*89c4ff92SAndroid Build Coastguard Worker     {
3591*89c4ff92SAndroid Build Coastguard Worker         params.m_CellToForgetWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[10]].get(),
3592*89c4ff92SAndroid Build Coastguard Worker                                                             inputTensorInfo).first;
3593*89c4ff92SAndroid Build Coastguard Worker     }
3594*89c4ff92SAndroid Build Coastguard Worker 
3595*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[11]))
3596*89c4ff92SAndroid Build Coastguard Worker     {
3597*89c4ff92SAndroid Build Coastguard Worker         params.m_CellToOutputWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[11]].get(),
3598*89c4ff92SAndroid Build Coastguard Worker                                                             inputTensorInfo).first;
3599*89c4ff92SAndroid Build Coastguard Worker     }
3600*89c4ff92SAndroid Build Coastguard Worker 
3601*89c4ff92SAndroid Build Coastguard Worker     // Gates bias tensors of size {n_cell}
3602*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[12]))
3603*89c4ff92SAndroid Build Coastguard Worker     {
3604*89c4ff92SAndroid Build Coastguard Worker         params.m_InputGateBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[12]].get(),
3605*89c4ff92SAndroid Build Coastguard Worker                                                       inputTensorInfo).first;
3606*89c4ff92SAndroid Build Coastguard Worker     }
3607*89c4ff92SAndroid Build Coastguard Worker 
3608*89c4ff92SAndroid Build Coastguard Worker     params.m_ForgetGateBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[13]].get(),
3609*89c4ff92SAndroid Build Coastguard Worker                                                    inputTensorInfo).first;
3610*89c4ff92SAndroid Build Coastguard Worker     params.m_CellBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[14]].get(),
3611*89c4ff92SAndroid Build Coastguard Worker                                              inputTensorInfo).first;
3612*89c4ff92SAndroid Build Coastguard Worker     params.m_OutputGateBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[15]].get(),
3613*89c4ff92SAndroid Build Coastguard Worker                                                    inputTensorInfo).first;
3614*89c4ff92SAndroid Build Coastguard Worker 
3615*89c4ff92SAndroid Build Coastguard Worker     // Projection weight tensor of size {n_output, n_cell}
3616*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[16]))
3617*89c4ff92SAndroid Build Coastguard Worker     {
3618*89c4ff92SAndroid Build Coastguard Worker         params.m_ProjectionWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[16]].get(),
3619*89c4ff92SAndroid Build Coastguard Worker                                                           inputTensorInfo).first;
3620*89c4ff92SAndroid Build Coastguard Worker     }
3621*89c4ff92SAndroid Build Coastguard Worker     // Projection bias tensor of size {n_output}
3622*89c4ff92SAndroid Build Coastguard Worker     if (IsOptionalOperandPresent(operatorPtr->inputs[17]))
3623*89c4ff92SAndroid Build Coastguard Worker     {
3624*89c4ff92SAndroid Build Coastguard Worker         params.m_ProjectionBias = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[17]].get(),
3625*89c4ff92SAndroid Build Coastguard Worker                                                        inputTensorInfo).first;
3626*89c4ff92SAndroid Build Coastguard Worker     }
3627*89c4ff92SAndroid Build Coastguard Worker 
3628*89c4ff92SAndroid Build Coastguard Worker     // These state tensors are defined as variable tensors, and will be modified by this op.
3629*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputStateInInfo = ToTensorInfo(subgraphPtr->tensors[operatorPtr->inputs[18]].get());
3630*89c4ff92SAndroid Build Coastguard Worker     m_ConstantsToBeCreated.push_back(operatorPtr->inputs[18]);
3631*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo cellStateInInfo = ToTensorInfo(subgraphPtr->tensors[operatorPtr->inputs[19]].get());
3632*89c4ff92SAndroid Build Coastguard Worker     m_ConstantsToBeCreated.push_back(operatorPtr->inputs[19]);
3633*89c4ff92SAndroid Build Coastguard Worker 
3634*89c4ff92SAndroid Build Coastguard Worker     // Layer norm coefficient tensors of size {n_cell}, representing a diagonal matrix.
3635*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() >= 21 && IsOptionalOperandPresent(operatorPtr->inputs[20]))
3636*89c4ff92SAndroid Build Coastguard Worker     {
3637*89c4ff92SAndroid Build Coastguard Worker         params.m_InputLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[20]].get(),
3638*89c4ff92SAndroid Build Coastguard Worker                                                               inputTensorInfo).first;
3639*89c4ff92SAndroid Build Coastguard Worker     }
3640*89c4ff92SAndroid Build Coastguard Worker 
3641*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() >= 22 && IsOptionalOperandPresent(operatorPtr->inputs[21]))
3642*89c4ff92SAndroid Build Coastguard Worker     {
3643*89c4ff92SAndroid Build Coastguard Worker         params.m_ForgetLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[21]].get(),
3644*89c4ff92SAndroid Build Coastguard Worker                                                                inputTensorInfo).first;
3645*89c4ff92SAndroid Build Coastguard Worker     }
3646*89c4ff92SAndroid Build Coastguard Worker 
3647*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() >= 23 && IsOptionalOperandPresent(operatorPtr->inputs[22]))
3648*89c4ff92SAndroid Build Coastguard Worker     {
3649*89c4ff92SAndroid Build Coastguard Worker         params.m_CellLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[22]].get(),
3650*89c4ff92SAndroid Build Coastguard Worker                                                              inputTensorInfo).first;
3651*89c4ff92SAndroid Build Coastguard Worker     }
3652*89c4ff92SAndroid Build Coastguard Worker 
3653*89c4ff92SAndroid Build Coastguard Worker     if (inputs.size() >= 24 && IsOptionalOperandPresent(operatorPtr->inputs[23]))
3654*89c4ff92SAndroid Build Coastguard Worker     {
3655*89c4ff92SAndroid Build Coastguard Worker         params.m_OutputLayerNormWeights = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->inputs[23]].get(),
3656*89c4ff92SAndroid Build Coastguard Worker                                                                inputTensorInfo).first;
3657*89c4ff92SAndroid Build Coastguard Worker     }
3658*89c4ff92SAndroid Build Coastguard Worker 
3659*89c4ff92SAndroid Build Coastguard Worker     // set the layer descriptor
3660*89c4ff92SAndroid Build Coastguard Worker     armnn::UnidirectionalSequenceLstmDescriptor desc;
3661*89c4ff92SAndroid Build Coastguard Worker     desc.m_ActivationFunc    = nodeParams->fused_activation_function;
3662*89c4ff92SAndroid Build Coastguard Worker     desc.m_ClippingThresCell = nodeParams->cell_clip;
3663*89c4ff92SAndroid Build Coastguard Worker     desc.m_ClippingThresProj = nodeParams->proj_clip;
3664*89c4ff92SAndroid Build Coastguard Worker     desc.m_CifgEnabled       = (params.m_InputToInputWeights == nullptr
3665*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_RecurrentToInputWeights == nullptr
3666*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_InputGateBias == nullptr);
3667*89c4ff92SAndroid Build Coastguard Worker     desc.m_PeepholeEnabled   = (params.m_CellToForgetWeights != nullptr || params.m_CellToOutputWeights != nullptr);
3668*89c4ff92SAndroid Build Coastguard Worker     desc.m_ProjectionEnabled = (params.m_ProjectionWeights != nullptr);
3669*89c4ff92SAndroid Build Coastguard Worker     desc.m_LayerNormEnabled  = (params.m_InputLayerNormWeights != nullptr
3670*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_ForgetLayerNormWeights != nullptr
3671*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_CellLayerNormWeights != nullptr
3672*89c4ff92SAndroid Build Coastguard Worker                                 || params.m_OutputLayerNormWeights != nullptr);
3673*89c4ff92SAndroid Build Coastguard Worker     desc.m_TimeMajor         = nodeParams->time_major;
3674*89c4ff92SAndroid Build Coastguard Worker 
3675*89c4ff92SAndroid Build Coastguard Worker     if (operatorPtr->intermediates.size() > 3 && desc.m_LayerNormEnabled)
3676*89c4ff92SAndroid Build Coastguard Worker     {
3677*89c4ff92SAndroid Build Coastguard Worker         auto inputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[0]].get(),
3678*89c4ff92SAndroid Build Coastguard Worker                                                       inputTensorInfo).first;
3679*89c4ff92SAndroid Build Coastguard Worker         auto inputIntermediateTensorInfo = inputIntermediate->GetInfo();
3680*89c4ff92SAndroid Build Coastguard Worker         desc.m_InputIntermediateScale = inputIntermediateTensorInfo.GetQuantizationScale();
3681*89c4ff92SAndroid Build Coastguard Worker 
3682*89c4ff92SAndroid Build Coastguard Worker         auto forgetIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[1]].get(),
3683*89c4ff92SAndroid Build Coastguard Worker                                                       inputTensorInfo).first;
3684*89c4ff92SAndroid Build Coastguard Worker         auto forgetIntermediateTensorInfo = forgetIntermediate->GetInfo();
3685*89c4ff92SAndroid Build Coastguard Worker         desc.m_ForgetIntermediateScale = forgetIntermediateTensorInfo.GetQuantizationScale();
3686*89c4ff92SAndroid Build Coastguard Worker 
3687*89c4ff92SAndroid Build Coastguard Worker         auto cellIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[2]].get(),
3688*89c4ff92SAndroid Build Coastguard Worker                                                       inputTensorInfo).first;
3689*89c4ff92SAndroid Build Coastguard Worker         auto cellIntermediateTensorInfo = cellIntermediate->GetInfo();
3690*89c4ff92SAndroid Build Coastguard Worker         desc.m_CellIntermediateScale = cellIntermediateTensorInfo.GetQuantizationScale();
3691*89c4ff92SAndroid Build Coastguard Worker 
3692*89c4ff92SAndroid Build Coastguard Worker         auto outputIntermediate = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[3]].get(),
3693*89c4ff92SAndroid Build Coastguard Worker                                                       inputTensorInfo).first;
3694*89c4ff92SAndroid Build Coastguard Worker         auto outputIntermediateTensorInfo = outputIntermediate->GetInfo();
3695*89c4ff92SAndroid Build Coastguard Worker         desc.m_OutputIntermediateScale = outputIntermediateTensorInfo.GetQuantizationScale();
3696*89c4ff92SAndroid Build Coastguard Worker     }
3697*89c4ff92SAndroid Build Coastguard Worker     else
3698*89c4ff92SAndroid Build Coastguard Worker     {
3699*89c4ff92SAndroid Build Coastguard Worker         float defaultIntermediate = std::pow(2, -12);
3700*89c4ff92SAndroid Build Coastguard Worker         desc.m_InputIntermediateScale = defaultIntermediate;
3701*89c4ff92SAndroid Build Coastguard Worker         desc.m_ForgetIntermediateScale = defaultIntermediate;
3702*89c4ff92SAndroid Build Coastguard Worker         desc.m_CellIntermediateScale = defaultIntermediate;
3703*89c4ff92SAndroid Build Coastguard Worker         desc.m_OutputIntermediateScale = defaultIntermediate;
3704*89c4ff92SAndroid Build Coastguard Worker     }
3705*89c4ff92SAndroid Build Coastguard Worker 
3706*89c4ff92SAndroid Build Coastguard Worker     if (operatorPtr->intermediates.size() > 4)
3707*89c4ff92SAndroid Build Coastguard Worker     {
3708*89c4ff92SAndroid Build Coastguard Worker         auto hiddentensor = CreateConstTensorPtr(subgraphPtr->tensors[operatorPtr->intermediates[4]].get(),
3709*89c4ff92SAndroid Build Coastguard Worker                                                  inputTensorInfo).first;
3710*89c4ff92SAndroid Build Coastguard Worker 
3711*89c4ff92SAndroid Build Coastguard Worker         desc.m_HiddenStateScale = hiddentensor->GetInfo().GetQuantizationScale();
3712*89c4ff92SAndroid Build Coastguard Worker         desc.m_HiddenStateZeroPoint = hiddentensor->GetInfo().GetQuantizationOffset();
3713*89c4ff92SAndroid Build Coastguard Worker     }
3714*89c4ff92SAndroid Build Coastguard Worker     unsigned int batchSize  = inputTensorInfo.GetShape()[0];
3715*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputSize = outputTensorInfo.GetShape()[2];
3716*89c4ff92SAndroid Build Coastguard Worker     unsigned int numUnits   = cellStateInInfo.GetShape()[1];
3717*89c4ff92SAndroid Build Coastguard Worker 
3718*89c4ff92SAndroid Build Coastguard Worker     armnn::DataType dataType = inputTensorInfo.GetDataType();
3719*89c4ff92SAndroid Build Coastguard Worker     float qScale = inputTensorInfo.GetQuantizationScale();
3720*89c4ff92SAndroid Build Coastguard Worker     float qOffset = inputTensorInfo.GetQuantizationOffset();
3721*89c4ff92SAndroid Build Coastguard Worker 
3722*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo scratchBufferTensorInfo({batchSize, numUnits * 3}, dataType, qScale, qOffset);
3723*89c4ff92SAndroid Build Coastguard Worker     if (!desc.m_CifgEnabled)
3724*89c4ff92SAndroid Build Coastguard Worker     {
3725*89c4ff92SAndroid Build Coastguard Worker         scratchBufferTensorInfo = armnn::TensorInfo({batchSize, numUnits * 4}, dataType, qScale, qOffset);
3726*89c4ff92SAndroid Build Coastguard Worker     }
3727*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo cellStateOutTensorInfo({batchSize, numUnits},
3728*89c4ff92SAndroid Build Coastguard Worker                                              cellStateInInfo.GetDataType(),
3729*89c4ff92SAndroid Build Coastguard Worker                                              cellStateInInfo.GetQuantizationScale(),
3730*89c4ff92SAndroid Build Coastguard Worker                                              cellStateInInfo.GetQuantizationOffset());
3731*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputStateOutTensorInfo({batchSize, outputSize}, dataType, qScale, qOffset);
3732*89c4ff92SAndroid Build Coastguard Worker 
3733*89c4ff92SAndroid Build Coastguard Worker     armnn::LstmInputParamsInfo paramsInfo;
3734*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_InputToForgetWeights     = &(params.m_InputToForgetWeights->GetInfo());
3735*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_InputToCellWeights       = &(params.m_InputToCellWeights->GetInfo());
3736*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_InputToOutputWeights     = &(params.m_InputToOutputWeights->GetInfo());
3737*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_RecurrentToForgetWeights = &(params.m_RecurrentToForgetWeights->GetInfo());
3738*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_RecurrentToCellWeights   = &(params.m_RecurrentToCellWeights->GetInfo());
3739*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_RecurrentToOutputWeights = &(params.m_RecurrentToOutputWeights->GetInfo());
3740*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_ForgetGateBias           = &(params.m_ForgetGateBias->GetInfo());
3741*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_CellBias                 = &(params.m_CellBias->GetInfo());
3742*89c4ff92SAndroid Build Coastguard Worker     paramsInfo.m_OutputGateBias           = &(params.m_OutputGateBias->GetInfo());
3743*89c4ff92SAndroid Build Coastguard Worker 
3744*89c4ff92SAndroid Build Coastguard Worker     if (!desc.m_CifgEnabled)
3745*89c4ff92SAndroid Build Coastguard Worker     {
3746*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_InputToInputWeights = &(params.m_InputToInputWeights->GetInfo());
3747*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_RecurrentToInputWeights = &(params.m_RecurrentToInputWeights->GetInfo());
3748*89c4ff92SAndroid Build Coastguard Worker         if (params.m_CellToInputWeights != nullptr)
3749*89c4ff92SAndroid Build Coastguard Worker         {
3750*89c4ff92SAndroid Build Coastguard Worker             paramsInfo.m_CellToInputWeights = &(params.m_CellToInputWeights->GetInfo());
3751*89c4ff92SAndroid Build Coastguard Worker         }
3752*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_InputGateBias = &(params.m_InputGateBias->GetInfo());
3753*89c4ff92SAndroid Build Coastguard Worker     }
3754*89c4ff92SAndroid Build Coastguard Worker 
3755*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_ProjectionEnabled)
3756*89c4ff92SAndroid Build Coastguard Worker     {
3757*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_ProjectionWeights = &(params.m_ProjectionWeights->GetInfo());
3758*89c4ff92SAndroid Build Coastguard Worker         if (params.m_ProjectionBias != nullptr)
3759*89c4ff92SAndroid Build Coastguard Worker         {
3760*89c4ff92SAndroid Build Coastguard Worker             paramsInfo.m_ProjectionBias = &(params.m_ProjectionBias->GetInfo());
3761*89c4ff92SAndroid Build Coastguard Worker         }
3762*89c4ff92SAndroid Build Coastguard Worker     }
3763*89c4ff92SAndroid Build Coastguard Worker 
3764*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_PeepholeEnabled)
3765*89c4ff92SAndroid Build Coastguard Worker     {
3766*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_CellToForgetWeights = &(params.m_CellToForgetWeights->GetInfo());
3767*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_CellToOutputWeights = &(params.m_CellToOutputWeights->GetInfo());
3768*89c4ff92SAndroid Build Coastguard Worker     }
3769*89c4ff92SAndroid Build Coastguard Worker 
3770*89c4ff92SAndroid Build Coastguard Worker     if (desc.m_LayerNormEnabled)
3771*89c4ff92SAndroid Build Coastguard Worker     {
3772*89c4ff92SAndroid Build Coastguard Worker         if(!desc.m_CifgEnabled)
3773*89c4ff92SAndroid Build Coastguard Worker         {
3774*89c4ff92SAndroid Build Coastguard Worker             paramsInfo.m_InputLayerNormWeights = &(params.m_InputLayerNormWeights->GetInfo());
3775*89c4ff92SAndroid Build Coastguard Worker         }
3776*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_ForgetLayerNormWeights = &(params.m_ForgetLayerNormWeights->GetInfo());
3777*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_CellLayerNormWeights = &(params.m_CellLayerNormWeights->GetInfo());
3778*89c4ff92SAndroid Build Coastguard Worker         paramsInfo.m_OutputLayerNormWeights = &(params.m_OutputLayerNormWeights->GetInfo());
3779*89c4ff92SAndroid Build Coastguard Worker     }
3780*89c4ff92SAndroid Build Coastguard Worker 
3781*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("UnidirectionalSequenceLSTM:{}:{}", subgraphIndex, operatorIndex);
3782*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* layer = m_Network->AddUnidirectionalSequenceLstmLayer(desc, params);
3783*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3784*89c4ff92SAndroid Build Coastguard Worker 
3785*89c4ff92SAndroid Build Coastguard Worker     // register the input connection slots for the layer, connections are made after all layers have been created
3786*89c4ff92SAndroid Build Coastguard Worker     // only the tensors for the inputs are relevant, exclude the const tensors
3787*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector({operatorPtr->inputs[0],
3788*89c4ff92SAndroid Build Coastguard Worker                                operatorPtr->inputs[18],
3789*89c4ff92SAndroid Build Coastguard Worker                                operatorPtr->inputs[19]});
3790*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0],
3791*89c4ff92SAndroid Build Coastguard Worker                                                              inputTensorIndexes[1],
3792*89c4ff92SAndroid Build Coastguard Worker                                                              inputTensorIndexes[2]});
3793*89c4ff92SAndroid Build Coastguard Worker 
3794*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
3795*89c4ff92SAndroid Build Coastguard Worker 
3796*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputStateOutTensorInfo);
3797*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(1).SetTensorInfo(cellStateOutTensorInfo);
3798*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(2).SetTensorInfo(outputTensorInfo);
3799*89c4ff92SAndroid Build Coastguard Worker 
3800*89c4ff92SAndroid Build Coastguard Worker     unsigned int tensorIndex = outputTensorIndexes[0];
3801*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot* slot = &(layer->GetOutputSlot(2));
3802*89c4ff92SAndroid Build Coastguard Worker     RegisterProducerOfTensor(subgraphIndex, tensorIndex, slot);
3803*89c4ff92SAndroid Build Coastguard Worker }
3804*89c4ff92SAndroid Build Coastguard Worker 
ParseUnpack(size_t subgraphIndex,size_t operatorIndex)3805*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseUnpack(size_t subgraphIndex, size_t operatorIndex)
3806*89c4ff92SAndroid Build Coastguard Worker {
3807*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3808*89c4ff92SAndroid Build Coastguard Worker 
3809*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3810*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsUnpackOptions();
3811*89c4ff92SAndroid Build Coastguard Worker 
3812*89c4ff92SAndroid Build Coastguard Worker     // This unpackAxis indicates the axis to unpack
3813*89c4ff92SAndroid Build Coastguard Worker     const unsigned int unpackAxis = CHECKED_NON_NEGATIVE(options->axis);
3814*89c4ff92SAndroid Build Coastguard Worker 
3815*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3816*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
3817*89c4ff92SAndroid Build Coastguard Worker 
3818*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3819*89c4ff92SAndroid Build Coastguard Worker 
3820*89c4ff92SAndroid Build Coastguard Worker     if (unpackAxis >= inputTensorInfo.GetNumDimensions())
3821*89c4ff92SAndroid Build Coastguard Worker     {
3822*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
3823*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("The unpack axis: {} cannot be greater than or equal to "
3824*89c4ff92SAndroid Build Coastguard Worker                             "the number of input dimension {} {}",
3825*89c4ff92SAndroid Build Coastguard Worker                             unpackAxis,
3826*89c4ff92SAndroid Build Coastguard Worker                             inputTensorInfo.GetNumDimensions(),
3827*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
3828*89c4ff92SAndroid Build Coastguard Worker     }
3829*89c4ff92SAndroid Build Coastguard Worker 
3830*89c4ff92SAndroid Build Coastguard Worker     unsigned int unpackNum = CHECKED_NON_NEGATIVE(options->num);
3831*89c4ff92SAndroid Build Coastguard Worker     // If num is not defined, automatically infer from the length of the dimension axis.
3832*89c4ff92SAndroid Build Coastguard Worker     if(unpackNum == 0)
3833*89c4ff92SAndroid Build Coastguard Worker     {
3834*89c4ff92SAndroid Build Coastguard Worker         unpackNum = inputTensorInfo.GetShape()[unpackAxis];
3835*89c4ff92SAndroid Build Coastguard Worker     }
3836*89c4ff92SAndroid Build Coastguard Worker 
3837*89c4ff92SAndroid Build Coastguard Worker     // If unpack number cannot be inferred and is still zero, throw ParseException.
3838*89c4ff92SAndroid Build Coastguard Worker     if(unpackNum == 0)
3839*89c4ff92SAndroid Build Coastguard Worker     {
3840*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("Number to unpack must greater than zero.");
3841*89c4ff92SAndroid Build Coastguard Worker     }
3842*89c4ff92SAndroid Build Coastguard Worker 
3843*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3844*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), unpackNum);
3845*89c4ff92SAndroid Build Coastguard Worker 
3846*89c4ff92SAndroid Build Coastguard Worker     auto inputDimSize = inputTensorInfo.GetNumDimensions();
3847*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> unpackDimSizes(inputDimSize);
3848*89c4ff92SAndroid Build Coastguard Worker 
3849*89c4ff92SAndroid Build Coastguard Worker     // Add current input shape to unpackDimSizes
3850*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputDimSize; ++i)
3851*89c4ff92SAndroid Build Coastguard Worker     {
3852*89c4ff92SAndroid Build Coastguard Worker         unpackDimSizes[i] = inputTensorInfo.GetShape()[i];
3853*89c4ff92SAndroid Build Coastguard Worker     }
3854*89c4ff92SAndroid Build Coastguard Worker 
3855*89c4ff92SAndroid Build Coastguard Worker     if (unpackDimSizes[unpackAxis] != unpackNum)
3856*89c4ff92SAndroid Build Coastguard Worker     {
3857*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("Number to unpack must be the same as length of the dimension to "
3858*89c4ff92SAndroid Build Coastguard Worker                              "unpack along.");
3859*89c4ff92SAndroid Build Coastguard Worker     }
3860*89c4ff92SAndroid Build Coastguard Worker 
3861*89c4ff92SAndroid Build Coastguard Worker     unpackDimSizes[unpackAxis] /= unpackNum;
3862*89c4ff92SAndroid Build Coastguard Worker 
3863*89c4ff92SAndroid Build Coastguard Worker     SplitterDescriptor splitDesc(unpackNum, static_cast<unsigned int>(unpackDimSizes.size()));
3864*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < unpackNum; ++j)
3865*89c4ff92SAndroid Build Coastguard Worker     {
3866*89c4ff92SAndroid Build Coastguard Worker         // Set the size of the views.
3867*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < unpackDimSizes.size(); ++dimIdx)
3868*89c4ff92SAndroid Build Coastguard Worker         {
3869*89c4ff92SAndroid Build Coastguard Worker             splitDesc.SetViewSize(j, dimIdx, unpackDimSizes[dimIdx]);
3870*89c4ff92SAndroid Build Coastguard Worker         }
3871*89c4ff92SAndroid Build Coastguard Worker         splitDesc.SetViewOriginCoord(j, unpackAxis, unpackDimSizes[unpackAxis] * j);
3872*89c4ff92SAndroid Build Coastguard Worker     }
3873*89c4ff92SAndroid Build Coastguard Worker 
3874*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Unpack:{}:{}", subgraphIndex, operatorIndex);
3875*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
3876*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
3877*89c4ff92SAndroid Build Coastguard Worker 
3878*89c4ff92SAndroid Build Coastguard Worker     TensorShape splitOutShape = TensorShape(static_cast<unsigned int>(unpackDimSizes.size()),
3879*89c4ff92SAndroid Build Coastguard Worker                                             unpackDimSizes.data());
3880*89c4ff92SAndroid Build Coastguard Worker 
3881*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
3882*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
3883*89c4ff92SAndroid Build Coastguard Worker 
3884*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> reshapeDims;
3885*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int axis = 0; axis < splitOutShape.GetNumDimensions(); ++axis)
3886*89c4ff92SAndroid Build Coastguard Worker     {
3887*89c4ff92SAndroid Build Coastguard Worker         if (axis != unpackAxis)
3888*89c4ff92SAndroid Build Coastguard Worker         {
3889*89c4ff92SAndroid Build Coastguard Worker             reshapeDims.push_back(splitOutShape[axis]);
3890*89c4ff92SAndroid Build Coastguard Worker         }
3891*89c4ff92SAndroid Build Coastguard Worker     }
3892*89c4ff92SAndroid Build Coastguard Worker 
3893*89c4ff92SAndroid Build Coastguard Worker     TensorShape reshapeOutputShape(splitOutShape.GetNumDimensions() -1, reshapeDims.data());
3894*89c4ff92SAndroid Build Coastguard Worker 
3895*89c4ff92SAndroid Build Coastguard Worker     // Create reshape to remove the unpacked dimension for unpack operator of each output from Splitter.
3896*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
3897*89c4ff92SAndroid Build Coastguard Worker     {
3898*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo outputTensorInfo  = ToTensorInfo(outputs[k], true);
3899*89c4ff92SAndroid Build Coastguard Worker         std::string reshapeLayerName = fmt::format("Reshape_for:{}", layer->GetName());
3900*89c4ff92SAndroid Build Coastguard Worker         armnn::ReshapeDescriptor desc;
3901*89c4ff92SAndroid Build Coastguard Worker         desc.m_TargetShape = reshapeOutputShape;
3902*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer* reshapeLayer = m_Network->AddReshapeLayer(desc, layerName.c_str());
3903*89c4ff92SAndroid Build Coastguard Worker 
3904*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(k).SetTensorInfo(armnn::TensorInfo(splitOutShape,
3905*89c4ff92SAndroid Build Coastguard Worker                                                                 outputTensorInfo.GetDataType(),
3906*89c4ff92SAndroid Build Coastguard Worker                                                                 outputTensorInfo.GetQuantizationScale(),
3907*89c4ff92SAndroid Build Coastguard Worker                                                                 outputTensorInfo.GetQuantizationOffset()));
3908*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(k).Connect(reshapeLayer->GetInputSlot(0));
3909*89c4ff92SAndroid Build Coastguard Worker 
3910*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
3911*89c4ff92SAndroid Build Coastguard Worker 
3912*89c4ff92SAndroid Build Coastguard Worker         uint32_t reshapedOutputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[k]);
3913*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot* slot = &(reshapeLayer->GetOutputSlot(0));
3914*89c4ff92SAndroid Build Coastguard Worker         RegisterProducerOfTensor(subgraphIndex, reshapedOutputId, slot);
3915*89c4ff92SAndroid Build Coastguard Worker     }
3916*89c4ff92SAndroid Build Coastguard Worker }
3917*89c4ff92SAndroid Build Coastguard Worker 
ParseSplit(size_t subgraphIndex,size_t operatorIndex)3918*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSplit(size_t subgraphIndex, size_t operatorIndex)
3919*89c4ff92SAndroid Build Coastguard Worker {
3920*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
3921*89c4ff92SAndroid Build Coastguard Worker 
3922*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
3923*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsSplitOptions();
3924*89c4ff92SAndroid Build Coastguard Worker 
3925*89c4ff92SAndroid Build Coastguard Worker     const unsigned int numSplits = CHECKED_NON_NEGATIVE(options->num_splits);
3926*89c4ff92SAndroid Build Coastguard Worker 
3927*89c4ff92SAndroid Build Coastguard Worker     // If number of splits cannot be inferred and is zero, throw ParseException.
3928*89c4ff92SAndroid Build Coastguard Worker     if(numSplits == 0)
3929*89c4ff92SAndroid Build Coastguard Worker     {
3930*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("Number to splits must greater than zero.");
3931*89c4ff92SAndroid Build Coastguard Worker     }
3932*89c4ff92SAndroid Build Coastguard Worker 
3933*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
3934*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
3935*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
3936*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), numSplits);
3937*89c4ff92SAndroid Build Coastguard Worker 
3938*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
3939*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo axisTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
3940*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(axisTensorInfo.GetNumElements() == 1);
3941*89c4ff92SAndroid Build Coastguard Worker 
3942*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[0]->buffer);
3943*89c4ff92SAndroid Build Coastguard Worker     if (axisBufferPtr == nullptr)
3944*89c4ff92SAndroid Build Coastguard Worker     {
3945*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
3946*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid inputs. Failed to read axis. {}",
3947*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
3948*89c4ff92SAndroid Build Coastguard Worker     }
3949*89c4ff92SAndroid Build Coastguard Worker 
3950*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> axisData(axisTensorInfo.GetNumElements());
3951*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
3952*89c4ff92SAndroid Build Coastguard Worker     int32_t axis = axisData[0];
3953*89c4ff92SAndroid Build Coastguard Worker 
3954*89c4ff92SAndroid Build Coastguard Worker     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
3955*89c4ff92SAndroid Build Coastguard Worker     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
3956*89c4ff92SAndroid Build Coastguard Worker     {
3957*89c4ff92SAndroid Build Coastguard Worker         // Square bracket denotes inclusive n while parenthesis denotes exclusive n
3958*89c4ff92SAndroid Build Coastguard Worker         // E.g. Rank 4 tensor can have axis in range [-4, 3)
3959*89c4ff92SAndroid Build Coastguard Worker         // -1 == 3, -2 == 2, -3 == 1, -4 == 0
3960*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
3961*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid axis: {}. Axis must be in range [-n, n) {}",
3962*89c4ff92SAndroid Build Coastguard Worker                             axis,
3963*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
3964*89c4ff92SAndroid Build Coastguard Worker     }
3965*89c4ff92SAndroid Build Coastguard Worker 
3966*89c4ff92SAndroid Build Coastguard Worker     const unsigned int splitDim = armnnUtils::GetUnsignedAxis(inputTensorInfo.GetNumDimensions(), axis);
3967*89c4ff92SAndroid Build Coastguard Worker 
3968*89c4ff92SAndroid Build Coastguard Worker     auto inputDimSize = inputTensorInfo.GetNumDimensions();
3969*89c4ff92SAndroid Build Coastguard Worker     if (inputDimSize > MaxNumOfTensorDimensions)
3970*89c4ff92SAndroid Build Coastguard Worker     {
3971*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
3972*89c4ff92SAndroid Build Coastguard Worker             fmt::format("The number of dimensions: {} for input tensors of the split op cannot be greater than {} {}",
3973*89c4ff92SAndroid Build Coastguard Worker                         inputTensorInfo.GetNumDimensions(),
3974*89c4ff92SAndroid Build Coastguard Worker                         MaxNumOfTensorDimensions,
3975*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
3976*89c4ff92SAndroid Build Coastguard Worker     }
3977*89c4ff92SAndroid Build Coastguard Worker 
3978*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> splitterDimSizes(inputDimSize);
3979*89c4ff92SAndroid Build Coastguard Worker 
3980*89c4ff92SAndroid Build Coastguard Worker     // Add current input shape to splitterDimSizes
3981*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputDimSize; ++i)
3982*89c4ff92SAndroid Build Coastguard Worker     {
3983*89c4ff92SAndroid Build Coastguard Worker         splitterDimSizes[i] = inputTensorInfo.GetShape()[i];
3984*89c4ff92SAndroid Build Coastguard Worker     }
3985*89c4ff92SAndroid Build Coastguard Worker 
3986*89c4ff92SAndroid Build Coastguard Worker     if (splitterDimSizes[splitDim] % numSplits != 0)
3987*89c4ff92SAndroid Build Coastguard Worker     {
3988*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("Number of splits must evenly divide the dimension");
3989*89c4ff92SAndroid Build Coastguard Worker     }
3990*89c4ff92SAndroid Build Coastguard Worker     splitterDimSizes[splitDim] /= numSplits;
3991*89c4ff92SAndroid Build Coastguard Worker 
3992*89c4ff92SAndroid Build Coastguard Worker     SplitterDescriptor splitDesc(numSplits, inputDimSize);
3993*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < numSplits; ++j)
3994*89c4ff92SAndroid Build Coastguard Worker     {
3995*89c4ff92SAndroid Build Coastguard Worker         // Set the size of the views.
3996*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < splitterDimSizes.size(); ++dimIdx)
3997*89c4ff92SAndroid Build Coastguard Worker         {
3998*89c4ff92SAndroid Build Coastguard Worker             splitDesc.SetViewSize(j, dimIdx, splitterDimSizes[dimIdx]);
3999*89c4ff92SAndroid Build Coastguard Worker         }
4000*89c4ff92SAndroid Build Coastguard Worker         splitDesc.SetViewOriginCoord(j, splitDim, splitterDimSizes[splitDim] * j);
4001*89c4ff92SAndroid Build Coastguard Worker     }
4002*89c4ff92SAndroid Build Coastguard Worker 
4003*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Split:{}:{}", subgraphIndex, operatorIndex);
4004*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
4005*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4006*89c4ff92SAndroid Build Coastguard Worker 
4007*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4008*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[1]});
4009*89c4ff92SAndroid Build Coastguard Worker 
4010*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
4011*89c4ff92SAndroid Build Coastguard Worker     {
4012*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo tensorInfo = ToTensorInfo(outputs[k], true);
4013*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(k).SetTensorInfo(tensorInfo);
4014*89c4ff92SAndroid Build Coastguard Worker     }
4015*89c4ff92SAndroid Build Coastguard Worker 
4016*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4017*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
4018*89c4ff92SAndroid Build Coastguard Worker }
4019*89c4ff92SAndroid Build Coastguard Worker 
ComputeWrappedIndex(int idx,unsigned int numDimsIn)4020*89c4ff92SAndroid Build Coastguard Worker unsigned int ComputeWrappedIndex(int idx, unsigned int numDimsIn)
4021*89c4ff92SAndroid Build Coastguard Worker {
4022*89c4ff92SAndroid Build Coastguard Worker     int numDims = armnn::numeric_cast<int>(numDimsIn);
4023*89c4ff92SAndroid Build Coastguard Worker     int v = idx < 0 ? numDims + idx : idx;
4024*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(v >= 0);
4025*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(v < numDims);
4026*89c4ff92SAndroid Build Coastguard Worker 
4027*89c4ff92SAndroid Build Coastguard Worker     return static_cast<unsigned int>(v);
4028*89c4ff92SAndroid Build Coastguard Worker }
4029*89c4ff92SAndroid Build Coastguard Worker 
ParseSplitV(size_t subgraphIndex,size_t operatorIndex)4030*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSplitV(size_t subgraphIndex, size_t operatorIndex)
4031*89c4ff92SAndroid Build Coastguard Worker {
4032*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4033*89c4ff92SAndroid Build Coastguard Worker 
4034*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
4035*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsSplitVOptions();
4036*89c4ff92SAndroid Build Coastguard Worker 
4037*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4038*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 3);
4039*89c4ff92SAndroid Build Coastguard Worker 
4040*89c4ff92SAndroid Build Coastguard Worker     auto& inputTensor = inputs[0];
4041*89c4ff92SAndroid Build Coastguard Worker     auto& splitsTensor = inputs[1];
4042*89c4ff92SAndroid Build Coastguard Worker     auto& axisTensor = inputs[2];
4043*89c4ff92SAndroid Build Coastguard Worker 
4044*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = ToTensorInfo(inputTensor);
4045*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo splitsInfo      = ToTensorInfo(splitsTensor);
4046*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo axisTensorInfo  = ToTensorInfo(axisTensor);
4047*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(axisTensorInfo.GetNumElements() == 1);
4048*89c4ff92SAndroid Build Coastguard Worker 
4049*89c4ff92SAndroid Build Coastguard Worker     // Inputs
4050*89c4ff92SAndroid Build Coastguard Worker     auto inputDimSize = inputTensorInfo.GetNumDimensions();
4051*89c4ff92SAndroid Build Coastguard Worker     if (inputDimSize > MaxNumOfTensorDimensions)
4052*89c4ff92SAndroid Build Coastguard Worker     {
4053*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4054*89c4ff92SAndroid Build Coastguard Worker             fmt::format("The number of dimensions: {} for input tensors of the "
4055*89c4ff92SAndroid Build Coastguard Worker                         "SplitV op cannot be greater than {} {}",
4056*89c4ff92SAndroid Build Coastguard Worker                         inputTensorInfo.GetNumDimensions(),
4057*89c4ff92SAndroid Build Coastguard Worker                         MaxNumOfTensorDimensions,
4058*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4059*89c4ff92SAndroid Build Coastguard Worker     }
4060*89c4ff92SAndroid Build Coastguard Worker 
4061*89c4ff92SAndroid Build Coastguard Worker     // Get split axis
4062*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr axisBufferPtr = GetBuffer(m_Model, axisTensor->buffer);
4063*89c4ff92SAndroid Build Coastguard Worker     if (axisBufferPtr == nullptr)
4064*89c4ff92SAndroid Build Coastguard Worker     {
4065*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4066*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid inputs. Failed to read axis. {}",
4067*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
4068*89c4ff92SAndroid Build Coastguard Worker     }
4069*89c4ff92SAndroid Build Coastguard Worker 
4070*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> axisData(axisTensorInfo.GetNumElements());
4071*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
4072*89c4ff92SAndroid Build Coastguard Worker     int32_t axis = axisData[0];
4073*89c4ff92SAndroid Build Coastguard Worker 
4074*89c4ff92SAndroid Build Coastguard Worker     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
4075*89c4ff92SAndroid Build Coastguard Worker     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
4076*89c4ff92SAndroid Build Coastguard Worker     {
4077*89c4ff92SAndroid Build Coastguard Worker         // Square bracket denotes inclusive n while parenthesis denotes exclusive n
4078*89c4ff92SAndroid Build Coastguard Worker         // E.g. Rank 4 tensor can have axis in range [-4, 3)
4079*89c4ff92SAndroid Build Coastguard Worker         // -1 == 3, -2 == 2, -3 == 1, -4 == 0
4080*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4081*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid axis: {}. Axis must be in range [-n, n) {}",
4082*89c4ff92SAndroid Build Coastguard Worker                             axis,
4083*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
4084*89c4ff92SAndroid Build Coastguard Worker     }
4085*89c4ff92SAndroid Build Coastguard Worker     const unsigned int splitDim = ComputeWrappedIndex(axis, inputTensorInfo.GetNumDimensions());
4086*89c4ff92SAndroid Build Coastguard Worker 
4087*89c4ff92SAndroid Build Coastguard Worker     // Set split sizes
4088*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(splitsInfo.GetNumDimensions(), 1);
4089*89c4ff92SAndroid Build Coastguard Worker     unsigned int numSplits{0};
4090*89c4ff92SAndroid Build Coastguard Worker 
4091*89c4ff92SAndroid Build Coastguard Worker     if(options)
4092*89c4ff92SAndroid Build Coastguard Worker     {
4093*89c4ff92SAndroid Build Coastguard Worker         numSplits = CHECKED_NON_NEGATIVE(options->num_splits);
4094*89c4ff92SAndroid Build Coastguard Worker     }
4095*89c4ff92SAndroid Build Coastguard Worker     else
4096*89c4ff92SAndroid Build Coastguard Worker     {
4097*89c4ff92SAndroid Build Coastguard Worker         numSplits = splitsInfo.GetNumElements();
4098*89c4ff92SAndroid Build Coastguard Worker     }
4099*89c4ff92SAndroid Build Coastguard Worker 
4100*89c4ff92SAndroid Build Coastguard Worker     if (numSplits <=0)
4101*89c4ff92SAndroid Build Coastguard Worker     {
4102*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("SplitV has invalid number of splits");
4103*89c4ff92SAndroid Build Coastguard Worker     }
4104*89c4ff92SAndroid Build Coastguard Worker 
4105*89c4ff92SAndroid Build Coastguard Worker     std::vector<int> splitsData(numSplits);
4106*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr splitsBufferPtr = GetBuffer(m_Model, splitsTensor->buffer);
4107*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(splitsData.data(), splitsBufferPtr->data.data(), splitsInfo.GetNumBytes());
4108*89c4ff92SAndroid Build Coastguard Worker 
4109*89c4ff92SAndroid Build Coastguard Worker     unsigned int idx = 0;
4110*89c4ff92SAndroid Build Coastguard Worker     int numInferred{0};
4111*89c4ff92SAndroid Build Coastguard Worker     unsigned int inferIdx{0};
4112*89c4ff92SAndroid Build Coastguard Worker     int splitSum{0};
4113*89c4ff92SAndroid Build Coastguard Worker     for (auto split : splitsData)
4114*89c4ff92SAndroid Build Coastguard Worker     {
4115*89c4ff92SAndroid Build Coastguard Worker         if (split < 0)
4116*89c4ff92SAndroid Build Coastguard Worker         {
4117*89c4ff92SAndroid Build Coastguard Worker             numInferred++;
4118*89c4ff92SAndroid Build Coastguard Worker             inferIdx = idx;
4119*89c4ff92SAndroid Build Coastguard Worker         }
4120*89c4ff92SAndroid Build Coastguard Worker         else
4121*89c4ff92SAndroid Build Coastguard Worker         {
4122*89c4ff92SAndroid Build Coastguard Worker             splitSum += split;
4123*89c4ff92SAndroid Build Coastguard Worker         }
4124*89c4ff92SAndroid Build Coastguard Worker         idx++;
4125*89c4ff92SAndroid Build Coastguard Worker     }
4126*89c4ff92SAndroid Build Coastguard Worker     // Check for inferred Axis
4127*89c4ff92SAndroid Build Coastguard Worker     if (numInferred == 0)
4128*89c4ff92SAndroid Build Coastguard Worker     {
4129*89c4ff92SAndroid Build Coastguard Worker         if (splitSum != armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]))
4130*89c4ff92SAndroid Build Coastguard Worker         {
4131*89c4ff92SAndroid Build Coastguard Worker             throw ParseException("SplitV split_sizes does not sum to the dimension of value along split_dim.");
4132*89c4ff92SAndroid Build Coastguard Worker         }
4133*89c4ff92SAndroid Build Coastguard Worker     }
4134*89c4ff92SAndroid Build Coastguard Worker     else if (numInferred == 1)
4135*89c4ff92SAndroid Build Coastguard Worker     {
4136*89c4ff92SAndroid Build Coastguard Worker         splitsData[inferIdx] = armnn::numeric_cast<int>(inputTensorInfo.GetShape()[splitDim]) - splitSum;
4137*89c4ff92SAndroid Build Coastguard Worker     }
4138*89c4ff92SAndroid Build Coastguard Worker     else
4139*89c4ff92SAndroid Build Coastguard Worker     {
4140*89c4ff92SAndroid Build Coastguard Worker         throw ParseException("Cannot infer split size for more than one split");
4141*89c4ff92SAndroid Build Coastguard Worker     }
4142*89c4ff92SAndroid Build Coastguard Worker 
4143*89c4ff92SAndroid Build Coastguard Worker     //Ouput size validation
4144*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4145*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), numSplits);
4146*89c4ff92SAndroid Build Coastguard Worker 
4147*89c4ff92SAndroid Build Coastguard Worker     // Setup Armnn descriptor
4148*89c4ff92SAndroid Build Coastguard Worker     SplitterDescriptor splitDesc(numSplits, inputDimSize);
4149*89c4ff92SAndroid Build Coastguard Worker     unsigned int accumSplit = 0;
4150*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int j = 0; j < numSplits; ++j)
4151*89c4ff92SAndroid Build Coastguard Worker     {
4152*89c4ff92SAndroid Build Coastguard Worker         unsigned int splitSize = armnn::numeric_cast<unsigned int>(splitsData[j]);
4153*89c4ff92SAndroid Build Coastguard Worker 
4154*89c4ff92SAndroid Build Coastguard Worker         // Set the size of the views.
4155*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int dimIdx = 0; dimIdx < inputTensorInfo.GetNumDimensions(); ++dimIdx)
4156*89c4ff92SAndroid Build Coastguard Worker         {
4157*89c4ff92SAndroid Build Coastguard Worker             unsigned int dimSize = inputTensorInfo.GetShape()[dimIdx];
4158*89c4ff92SAndroid Build Coastguard Worker             if (dimIdx == splitDim)
4159*89c4ff92SAndroid Build Coastguard Worker             {
4160*89c4ff92SAndroid Build Coastguard Worker                 dimSize = splitSize;
4161*89c4ff92SAndroid Build Coastguard Worker             }
4162*89c4ff92SAndroid Build Coastguard Worker             splitDesc.SetViewSize(j, dimIdx, dimSize);
4163*89c4ff92SAndroid Build Coastguard Worker         }
4164*89c4ff92SAndroid Build Coastguard Worker 
4165*89c4ff92SAndroid Build Coastguard Worker         splitDesc.SetViewOriginCoord(j, splitDim, accumSplit);
4166*89c4ff92SAndroid Build Coastguard Worker         accumSplit += splitSize;
4167*89c4ff92SAndroid Build Coastguard Worker     }
4168*89c4ff92SAndroid Build Coastguard Worker 
4169*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("SplitV:{}:{}", subgraphIndex, operatorIndex);
4170*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddSplitterLayer(splitDesc, layerName.c_str());
4171*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4172*89c4ff92SAndroid Build Coastguard Worker 
4173*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4174*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
4175*89c4ff92SAndroid Build Coastguard Worker 
4176*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int k = 0; k < layer->GetNumOutputSlots(); ++k)
4177*89c4ff92SAndroid Build Coastguard Worker     {
4178*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo tensorInfo = ToTensorInfo(outputs[k], true);
4179*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(k).SetTensorInfo(tensorInfo);
4180*89c4ff92SAndroid Build Coastguard Worker     }
4181*89c4ff92SAndroid Build Coastguard Worker 
4182*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4183*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
4184*89c4ff92SAndroid Build Coastguard Worker }
4185*89c4ff92SAndroid Build Coastguard Worker 
ParseArgMin(size_t subgraphIndex,size_t operatorIndex)4186*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseArgMin(size_t subgraphIndex, size_t operatorIndex)
4187*89c4ff92SAndroid Build Coastguard Worker {
4188*89c4ff92SAndroid Build Coastguard Worker     ParseArgMinMax(subgraphIndex, operatorIndex, armnn::ArgMinMaxFunction::Min);
4189*89c4ff92SAndroid Build Coastguard Worker }
4190*89c4ff92SAndroid Build Coastguard Worker 
ParseArgMax(size_t subgraphIndex,size_t operatorIndex)4191*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseArgMax(size_t subgraphIndex, size_t operatorIndex)
4192*89c4ff92SAndroid Build Coastguard Worker {
4193*89c4ff92SAndroid Build Coastguard Worker     ParseArgMinMax(subgraphIndex, operatorIndex, armnn::ArgMinMaxFunction::Max);
4194*89c4ff92SAndroid Build Coastguard Worker }
4195*89c4ff92SAndroid Build Coastguard Worker 
ParseArgMinMax(size_t subgraphIndex,size_t operatorIndex,ArgMinMaxFunction argMinMaxFunction)4196*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseArgMinMax(size_t subgraphIndex, size_t operatorIndex, ArgMinMaxFunction argMinMaxFunction)
4197*89c4ff92SAndroid Build Coastguard Worker {
4198*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4199*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4200*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
4201*89c4ff92SAndroid Build Coastguard Worker 
4202*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4203*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4204*89c4ff92SAndroid Build Coastguard Worker 
4205*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
4206*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo axisTensorInfo   = InputTensorInfo(subgraphIndex, operatorIndex, 1);
4207*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
4208*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(axisTensorInfo.GetNumElements() == 1);
4209*89c4ff92SAndroid Build Coastguard Worker 
4210*89c4ff92SAndroid Build Coastguard Worker     // Check if output tensor type is Signed32 or Signed64
4211*89c4ff92SAndroid Build Coastguard Worker     if (outputTensorInfo.GetDataType() != armnn::DataType::Signed32 &&
4212*89c4ff92SAndroid Build Coastguard Worker         outputTensorInfo.GetDataType() != armnn::DataType::Signed64)
4213*89c4ff92SAndroid Build Coastguard Worker     {
4214*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4215*89c4ff92SAndroid Build Coastguard Worker                 fmt::format(
4216*89c4ff92SAndroid Build Coastguard Worker                         "Output tensor data type is not supported. (Supported types: Signed32 & Signed64) {}",
4217*89c4ff92SAndroid Build Coastguard Worker                                 CHECK_LOCATION().AsString()));
4218*89c4ff92SAndroid Build Coastguard Worker     }
4219*89c4ff92SAndroid Build Coastguard Worker 
4220*89c4ff92SAndroid Build Coastguard Worker     // Get const axis value from model and set it to descriptor.
4221*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
4222*89c4ff92SAndroid Build Coastguard Worker     if (axisBufferPtr == nullptr)
4223*89c4ff92SAndroid Build Coastguard Worker     {
4224*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4225*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid inputs. Failed to read axis. {}",
4226*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
4227*89c4ff92SAndroid Build Coastguard Worker     }
4228*89c4ff92SAndroid Build Coastguard Worker 
4229*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> axisData(axisTensorInfo.GetNumElements());
4230*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(axisData.data(), axisBufferPtr->data.data(), axisTensorInfo.GetNumBytes());
4231*89c4ff92SAndroid Build Coastguard Worker     int32_t axis = axisData.front();
4232*89c4ff92SAndroid Build Coastguard Worker 
4233*89c4ff92SAndroid Build Coastguard Worker     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
4234*89c4ff92SAndroid Build Coastguard Worker     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
4235*89c4ff92SAndroid Build Coastguard Worker     {
4236*89c4ff92SAndroid Build Coastguard Worker         // Square bracket denotes inclusive n while parenthesis denotes exclusive n
4237*89c4ff92SAndroid Build Coastguard Worker         // E.g. Rank 4 tensor can have axis in range [-4, 3)
4238*89c4ff92SAndroid Build Coastguard Worker         // -1 == 3, -2 == 2, -3 == 1, -4 == 0
4239*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4240*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("Operation has invalid axis: {}. Axis must be in range [-n, n) {}",
4241*89c4ff92SAndroid Build Coastguard Worker                                     axis,
4242*89c4ff92SAndroid Build Coastguard Worker                                     CHECK_LOCATION().AsString()));
4243*89c4ff92SAndroid Build Coastguard Worker     }
4244*89c4ff92SAndroid Build Coastguard Worker 
4245*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxDescriptor desc;
4246*89c4ff92SAndroid Build Coastguard Worker     desc.m_Axis = axis;
4247*89c4ff92SAndroid Build Coastguard Worker     desc.m_Function = argMinMaxFunction;
4248*89c4ff92SAndroid Build Coastguard Worker 
4249*89c4ff92SAndroid Build Coastguard Worker     // Register a ArgMin/ArgMax layer.
4250*89c4ff92SAndroid Build Coastguard Worker     auto layerName = argMinMaxFunction == ArgMinMaxFunction::Max ? "ArgMax:{}:{}" : "ArgMin:{}:{}";
4251*89c4ff92SAndroid Build Coastguard Worker     auto layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
4252*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer *layer = m_Network->AddArgMinMaxLayer(desc, layerNameFormatted.c_str());
4253*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4254*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
4255*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4256*89c4ff92SAndroid Build Coastguard Worker 
4257*89c4ff92SAndroid Build Coastguard Worker     // Register input tensor to the layer.
4258*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4259*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
4260*89c4ff92SAndroid Build Coastguard Worker 
4261*89c4ff92SAndroid Build Coastguard Worker     // Register output tensor to the layer.
4262*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4263*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
4264*89c4ff92SAndroid Build Coastguard Worker }
4265*89c4ff92SAndroid Build Coastguard Worker 
ParseGather(size_t subgraphIndex,size_t operatorIndex)4266*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseGather(size_t subgraphIndex, size_t operatorIndex)
4267*89c4ff92SAndroid Build Coastguard Worker {
4268*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4269*89c4ff92SAndroid Build Coastguard Worker 
4270*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4271*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
4272*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4273*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4274*89c4ff92SAndroid Build Coastguard Worker 
4275*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
4276*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
4277*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
4278*89c4ff92SAndroid Build Coastguard Worker 
4279*89c4ff92SAndroid Build Coastguard Worker     armnn::GatherDescriptor gatherDescriptor;
4280*89c4ff92SAndroid Build Coastguard Worker 
4281*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
4282*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsGatherOptions();
4283*89c4ff92SAndroid Build Coastguard Worker     auto axis = options->axis;
4284*89c4ff92SAndroid Build Coastguard Worker 
4285*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Gather:{}:{}", subgraphIndex, operatorIndex);
4286*89c4ff92SAndroid Build Coastguard Worker 
4287*89c4ff92SAndroid Build Coastguard Worker     auto inputDimensions = static_cast<int32_t>(inputTensorInfo.GetNumDimensions());
4288*89c4ff92SAndroid Build Coastguard Worker     auto indicesDimensions = indicesTensorInfo.GetNumDimensions();
4289*89c4ff92SAndroid Build Coastguard Worker     auto outputDimensions = outputTensorInfo.GetNumDimensions();
4290*89c4ff92SAndroid Build Coastguard Worker     if (((axis < -inputDimensions) && (axis < 0)) || ((axis >= inputDimensions) && (axis > 0)))
4291*89c4ff92SAndroid Build Coastguard Worker     {
4292*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4293*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Operation has invalid axis: {} It is out of bounds [ -{}, {} ) {}",
4294*89c4ff92SAndroid Build Coastguard Worker                         axis,
4295*89c4ff92SAndroid Build Coastguard Worker                         inputDimensions, inputDimensions,
4296*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4297*89c4ff92SAndroid Build Coastguard Worker     }
4298*89c4ff92SAndroid Build Coastguard Worker     if (outputDimensions != static_cast<unsigned int>(inputDimensions) + indicesDimensions - 1)
4299*89c4ff92SAndroid Build Coastguard Worker     {
4300*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4301*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Operation has invalid output dimensions: {} Output must be an ({} + {} - 1) -D tensor {}",
4302*89c4ff92SAndroid Build Coastguard Worker                         outputDimensions,
4303*89c4ff92SAndroid Build Coastguard Worker                         inputDimensions, indicesDimensions,
4304*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4305*89c4ff92SAndroid Build Coastguard Worker     }
4306*89c4ff92SAndroid Build Coastguard Worker 
4307*89c4ff92SAndroid Build Coastguard Worker     gatherDescriptor.m_Axis = axis;
4308*89c4ff92SAndroid Build Coastguard Worker 
4309*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddGatherLayer(gatherDescriptor, layerName.c_str());
4310*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4311*89c4ff92SAndroid Build Coastguard Worker     outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
4312*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4313*89c4ff92SAndroid Build Coastguard Worker 
4314*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4315*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
4316*89c4ff92SAndroid Build Coastguard Worker 
4317*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4318*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
4319*89c4ff92SAndroid Build Coastguard Worker }
4320*89c4ff92SAndroid Build Coastguard Worker 
ParseGatherNd(size_t subgraphIndex,size_t operatorIndex)4321*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseGatherNd(size_t subgraphIndex, size_t operatorIndex)
4322*89c4ff92SAndroid Build Coastguard Worker {
4323*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4324*89c4ff92SAndroid Build Coastguard Worker 
4325*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4326*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
4327*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4328*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4329*89c4ff92SAndroid Build Coastguard Worker 
4330*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
4331*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
4332*89c4ff92SAndroid Build Coastguard Worker 
4333*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("GatherNd:{}:{}", subgraphIndex, operatorIndex);
4334*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddGatherNdLayer(layerName.c_str());
4335*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4336*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
4337*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4338*89c4ff92SAndroid Build Coastguard Worker 
4339*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4340*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
4341*89c4ff92SAndroid Build Coastguard Worker 
4342*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4343*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
4344*89c4ff92SAndroid Build Coastguard Worker }
4345*89c4ff92SAndroid Build Coastguard Worker 
ParseDepthToSpace(size_t subgraphIndex,size_t operatorIndex)4346*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseDepthToSpace(size_t subgraphIndex, size_t operatorIndex)
4347*89c4ff92SAndroid Build Coastguard Worker {
4348*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4349*89c4ff92SAndroid Build Coastguard Worker 
4350*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4351*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
4352*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::TensorRawPtrVector outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4353*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4354*89c4ff92SAndroid Build Coastguard Worker 
4355*89c4ff92SAndroid Build Coastguard Worker     armnn::DepthToSpaceDescriptor descriptor;
4356*89c4ff92SAndroid Build Coastguard Worker 
4357*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
4358*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsDepthToSpaceOptions();
4359*89c4ff92SAndroid Build Coastguard Worker     auto blockSize = options->block_size;
4360*89c4ff92SAndroid Build Coastguard Worker     if (blockSize < 2)
4361*89c4ff92SAndroid Build Coastguard Worker     {
4362*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4363*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Operation has invalid block size: {} Block size should be >= 2 {}",
4364*89c4ff92SAndroid Build Coastguard Worker                         blockSize,
4365*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4366*89c4ff92SAndroid Build Coastguard Worker     }
4367*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_BlockSize = armnn::numeric_cast<uint32_t>(blockSize);
4368*89c4ff92SAndroid Build Coastguard Worker 
4369*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("DepthToSpace:{}:{}", subgraphIndex, operatorIndex);
4370*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddDepthToSpaceLayer(descriptor, layerName.c_str());
4371*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4372*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
4373*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4374*89c4ff92SAndroid Build Coastguard Worker 
4375*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4376*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
4377*89c4ff92SAndroid Build Coastguard Worker 
4378*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4379*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
4380*89c4ff92SAndroid Build Coastguard Worker }
4381*89c4ff92SAndroid Build Coastguard Worker 
ParseSum(size_t subgraphIndex,size_t operatorIndex)4382*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSum(size_t subgraphIndex, size_t operatorIndex)
4383*89c4ff92SAndroid Build Coastguard Worker {
4384*89c4ff92SAndroid Build Coastguard Worker     ParseReduce(subgraphIndex, operatorIndex, armnn::ReduceOperation::Sum);
4385*89c4ff92SAndroid Build Coastguard Worker }
4386*89c4ff92SAndroid Build Coastguard Worker 
ParseReduceProd(size_t subgraphIndex,size_t operatorIndex)4387*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseReduceProd(size_t subgraphIndex, size_t operatorIndex)
4388*89c4ff92SAndroid Build Coastguard Worker {
4389*89c4ff92SAndroid Build Coastguard Worker     ParseReduce(subgraphIndex, operatorIndex, armnn::ReduceOperation::Prod);
4390*89c4ff92SAndroid Build Coastguard Worker }
4391*89c4ff92SAndroid Build Coastguard Worker 
ParseReduceMax(size_t subgraphIndex,size_t operatorIndex)4392*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseReduceMax(size_t subgraphIndex, size_t operatorIndex)
4393*89c4ff92SAndroid Build Coastguard Worker {
4394*89c4ff92SAndroid Build Coastguard Worker     ParseReduce(subgraphIndex, operatorIndex, armnn::ReduceOperation::Max);
4395*89c4ff92SAndroid Build Coastguard Worker }
4396*89c4ff92SAndroid Build Coastguard Worker 
ParseReduceMin(size_t subgraphIndex,size_t operatorIndex)4397*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseReduceMin(size_t subgraphIndex, size_t operatorIndex)
4398*89c4ff92SAndroid Build Coastguard Worker {
4399*89c4ff92SAndroid Build Coastguard Worker     ParseReduce(subgraphIndex, operatorIndex, armnn::ReduceOperation::Min);
4400*89c4ff92SAndroid Build Coastguard Worker }
4401*89c4ff92SAndroid Build Coastguard Worker 
ParseReduce(size_t subgraphIndex,size_t operatorIndex,ReduceOperation reduceOperation)4402*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseReduce(size_t subgraphIndex, size_t operatorIndex, ReduceOperation reduceOperation)
4403*89c4ff92SAndroid Build Coastguard Worker {
4404*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4405*89c4ff92SAndroid Build Coastguard Worker 
4406*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
4407*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsReducerOptions();
4408*89c4ff92SAndroid Build Coastguard Worker 
4409*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4410*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
4411*89c4ff92SAndroid Build Coastguard Worker 
4412*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4413*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4414*89c4ff92SAndroid Build Coastguard Worker 
4415*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("Reduce:{}:{}", subgraphIndex, operatorIndex);
4416*89c4ff92SAndroid Build Coastguard Worker 
4417*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo0 = InputTensorInfo(subgraphIndex, operatorIndex, 0);
4418*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo1 = InputTensorInfo(subgraphIndex, operatorIndex, 1);
4419*89c4ff92SAndroid Build Coastguard Worker 
4420*89c4ff92SAndroid Build Coastguard Worker     ReduceDescriptor desc;
4421*89c4ff92SAndroid Build Coastguard Worker     BufferRawPtr axisBufferPtr = GetBuffer(m_Model, inputs[1]->buffer);
4422*89c4ff92SAndroid Build Coastguard Worker     // Get const axis value from model and set it to descriptor.
4423*89c4ff92SAndroid Build Coastguard Worker     if (axisBufferPtr != nullptr)
4424*89c4ff92SAndroid Build Coastguard Worker     {
4425*89c4ff92SAndroid Build Coastguard Worker         std::vector<int32_t> axisData(inputTensorInfo1.GetNumElements());
4426*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(axisData.data(), axisBufferPtr->data.data(), inputTensorInfo1.GetNumBytes());
4427*89c4ff92SAndroid Build Coastguard Worker 
4428*89c4ff92SAndroid Build Coastguard Worker         // Convert the axis to unsigned int and remove duplicates.
4429*89c4ff92SAndroid Build Coastguard Worker         auto rank = static_cast<int32_t>(inputTensorInfo0.GetNumDimensions());
4430*89c4ff92SAndroid Build Coastguard Worker         std::set<unsigned int> uniqueAxis;
4431*89c4ff92SAndroid Build Coastguard Worker         std::transform(axisData.begin(),
4432*89c4ff92SAndroid Build Coastguard Worker                        axisData.end(),
4433*89c4ff92SAndroid Build Coastguard Worker                        std::inserter(uniqueAxis, uniqueAxis.begin()),
4434*89c4ff92SAndroid Build Coastguard Worker                        [rank](int i)->unsigned int{
4435*89c4ff92SAndroid Build Coastguard Worker                                return static_cast<uint32_t>(((i + rank) % rank)); });
4436*89c4ff92SAndroid Build Coastguard Worker         desc.m_vAxis.assign(uniqueAxis.begin(), uniqueAxis.end());
4437*89c4ff92SAndroid Build Coastguard Worker     }
4438*89c4ff92SAndroid Build Coastguard Worker     else
4439*89c4ff92SAndroid Build Coastguard Worker     {
4440*89c4ff92SAndroid Build Coastguard Worker         for (uint32_t i = 0; i < inputTensorInfo0.GetNumDimensions(); ++i)
4441*89c4ff92SAndroid Build Coastguard Worker         {
4442*89c4ff92SAndroid Build Coastguard Worker             desc.m_vAxis.push_back(i);
4443*89c4ff92SAndroid Build Coastguard Worker         }
4444*89c4ff92SAndroid Build Coastguard Worker     }
4445*89c4ff92SAndroid Build Coastguard Worker 
4446*89c4ff92SAndroid Build Coastguard Worker     desc.m_KeepDims        = options->keep_dims;
4447*89c4ff92SAndroid Build Coastguard Worker     desc.m_ReduceOperation = reduceOperation;
4448*89c4ff92SAndroid Build Coastguard Worker 
4449*89c4ff92SAndroid Build Coastguard Worker     // Register a new layer object, Sum.
4450*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddReduceLayer(desc, layerName.c_str());
4451*89c4ff92SAndroid Build Coastguard Worker 
4452*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
4453*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4454*89c4ff92SAndroid Build Coastguard Worker 
4455*89c4ff92SAndroid Build Coastguard Worker     // Register input tensor to the layer.
4456*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4457*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
4458*89c4ff92SAndroid Build Coastguard Worker 
4459*89c4ff92SAndroid Build Coastguard Worker     // Register output tensor to the layer.
4460*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4461*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
4462*89c4ff92SAndroid Build Coastguard Worker }
4463*89c4ff92SAndroid Build Coastguard Worker 
ParseLocalResponseNormalization(size_t subgraphIndex,size_t operatorIndex)4464*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLocalResponseNormalization(size_t subgraphIndex, size_t operatorIndex)
4465*89c4ff92SAndroid Build Coastguard Worker {
4466*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4467*89c4ff92SAndroid Build Coastguard Worker 
4468*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4469*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
4470*89c4ff92SAndroid Build Coastguard Worker 
4471*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4472*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4473*89c4ff92SAndroid Build Coastguard Worker 
4474*89c4ff92SAndroid Build Coastguard Worker     auto layerName = fmt::format("LRN:{}:{}", subgraphIndex, operatorIndex);
4475*89c4ff92SAndroid Build Coastguard Worker     std::string layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
4476*89c4ff92SAndroid Build Coastguard Worker 
4477*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 0);
4478*89c4ff92SAndroid Build Coastguard Worker 
4479*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = m_Model->subgraphs[subgraphIndex]->operators[operatorIndex];
4480*89c4ff92SAndroid Build Coastguard Worker     const auto* options = operatorPtr->builtin_options.AsLocalResponseNormalizationOptions();
4481*89c4ff92SAndroid Build Coastguard Worker 
4482*89c4ff92SAndroid Build Coastguard Worker     armnn::NormalizationDescriptor descriptor;
4483*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_DataLayout      = armnn::DataLayout::NHWC;
4484*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NormChannelType = armnn::NormalizationAlgorithmChannel::Across;
4485*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NormMethodType  = armnn::NormalizationAlgorithmMethod::LocalBrightness;
4486*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NormSize = static_cast<uint32_t>(options->radius);
4487*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_K = options->bias;
4488*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Alpha = options->alpha;
4489*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_Beta = options->beta;
4490*89c4ff92SAndroid Build Coastguard Worker 
4491*89c4ff92SAndroid Build Coastguard Worker     // ArmNN expects normSize to be the full size of the normalization
4492*89c4ff92SAndroid Build Coastguard Worker     // window rather than the radius as in TfLite.
4493*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_NormSize = 1 + (2 * descriptor.m_NormSize);
4494*89c4ff92SAndroid Build Coastguard Worker 
4495*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddNormalizationLayer(descriptor, layerNameFormatted.c_str());
4496*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4497*89c4ff92SAndroid Build Coastguard Worker 
4498*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
4499*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4500*89c4ff92SAndroid Build Coastguard Worker 
4501*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4502*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
4503*89c4ff92SAndroid Build Coastguard Worker 
4504*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4505*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
4506*89c4ff92SAndroid Build Coastguard Worker }
4507*89c4ff92SAndroid Build Coastguard Worker 
ParseAbs(size_t subgraphIndex,size_t operatorIndex)4508*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseAbs(size_t subgraphIndex, size_t operatorIndex)
4509*89c4ff92SAndroid Build Coastguard Worker {
4510*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Abs);
4511*89c4ff92SAndroid Build Coastguard Worker }
4512*89c4ff92SAndroid Build Coastguard Worker 
ParseCeil(size_t subgraphIndex,size_t operatorIndex)4513*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseCeil(size_t subgraphIndex, size_t operatorIndex)
4514*89c4ff92SAndroid Build Coastguard Worker {
4515*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Ceil);
4516*89c4ff92SAndroid Build Coastguard Worker }
4517*89c4ff92SAndroid Build Coastguard Worker 
ParseExp(size_t subgraphIndex,size_t operatorIndex)4518*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseExp(size_t subgraphIndex, size_t operatorIndex)
4519*89c4ff92SAndroid Build Coastguard Worker {
4520*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Exp);
4521*89c4ff92SAndroid Build Coastguard Worker }
4522*89c4ff92SAndroid Build Coastguard Worker 
ParseLog(size_t subgraphIndex,size_t operatorIndex)4523*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLog(size_t subgraphIndex, size_t operatorIndex)
4524*89c4ff92SAndroid Build Coastguard Worker {
4525*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Log);
4526*89c4ff92SAndroid Build Coastguard Worker }
4527*89c4ff92SAndroid Build Coastguard Worker 
ParseLogicalNot(size_t subgraphIndex,size_t operatorIndex)4528*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLogicalNot(size_t subgraphIndex, size_t operatorIndex)
4529*89c4ff92SAndroid Build Coastguard Worker {
4530*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::LogicalNot);
4531*89c4ff92SAndroid Build Coastguard Worker }
4532*89c4ff92SAndroid Build Coastguard Worker 
ParseNeg(size_t subgraphIndex,size_t operatorIndex)4533*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseNeg(size_t subgraphIndex, size_t operatorIndex)
4534*89c4ff92SAndroid Build Coastguard Worker {
4535*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Neg);
4536*89c4ff92SAndroid Build Coastguard Worker }
4537*89c4ff92SAndroid Build Coastguard Worker 
ParseRsqrt(size_t subgraphIndex,size_t operatorIndex)4538*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseRsqrt(size_t subgraphIndex, size_t operatorIndex)
4539*89c4ff92SAndroid Build Coastguard Worker {
4540*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Rsqrt);
4541*89c4ff92SAndroid Build Coastguard Worker }
4542*89c4ff92SAndroid Build Coastguard Worker 
ParseSin(size_t subgraphIndex,size_t operatorIndex)4543*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSin(size_t subgraphIndex, size_t operatorIndex)
4544*89c4ff92SAndroid Build Coastguard Worker {
4545*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Sin);
4546*89c4ff92SAndroid Build Coastguard Worker }
4547*89c4ff92SAndroid Build Coastguard Worker 
ParseSqrt(size_t subgraphIndex,size_t operatorIndex)4548*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseSqrt(size_t subgraphIndex, size_t operatorIndex)
4549*89c4ff92SAndroid Build Coastguard Worker {
4550*89c4ff92SAndroid Build Coastguard Worker     ParseElementwiseUnary(subgraphIndex, operatorIndex, armnn::UnaryOperation::Sqrt);
4551*89c4ff92SAndroid Build Coastguard Worker }
4552*89c4ff92SAndroid Build Coastguard Worker 
ParseElementwiseUnary(size_t subgraphIndex,size_t operatorIndex,UnaryOperation unaryOperation)4553*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseElementwiseUnary(size_t subgraphIndex, size_t operatorIndex, UnaryOperation unaryOperation)
4554*89c4ff92SAndroid Build Coastguard Worker {
4555*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4556*89c4ff92SAndroid Build Coastguard Worker 
4557*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4558*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 1);
4559*89c4ff92SAndroid Build Coastguard Worker 
4560*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4561*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4562*89c4ff92SAndroid Build Coastguard Worker 
4563*89c4ff92SAndroid Build Coastguard Worker     std::string layerName = std::string(GetUnaryOperationAsCString(unaryOperation)) + ":{}:{}";
4564*89c4ff92SAndroid Build Coastguard Worker     std::string layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
4565*89c4ff92SAndroid Build Coastguard Worker 
4566*89c4ff92SAndroid Build Coastguard Worker     ElementwiseUnaryDescriptor desc;
4567*89c4ff92SAndroid Build Coastguard Worker     desc.m_Operation = unaryOperation;
4568*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddElementwiseUnaryLayer(desc, layerNameFormatted.c_str());
4569*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4570*89c4ff92SAndroid Build Coastguard Worker 
4571*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0});
4572*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4573*89c4ff92SAndroid Build Coastguard Worker 
4574*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4575*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0]});
4576*89c4ff92SAndroid Build Coastguard Worker 
4577*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4578*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, outputTensorIndexes);
4579*89c4ff92SAndroid Build Coastguard Worker }
4580*89c4ff92SAndroid Build Coastguard Worker 
ParseEqual(size_t subgraphIndex,size_t operatorIndex)4581*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseEqual(size_t subgraphIndex, size_t operatorIndex)
4582*89c4ff92SAndroid Build Coastguard Worker {
4583*89c4ff92SAndroid Build Coastguard Worker     ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Equal);
4584*89c4ff92SAndroid Build Coastguard Worker }
4585*89c4ff92SAndroid Build Coastguard Worker 
ParseNotEqual(size_t subgraphIndex,size_t operatorIndex)4586*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseNotEqual(size_t subgraphIndex, size_t operatorIndex)
4587*89c4ff92SAndroid Build Coastguard Worker {
4588*89c4ff92SAndroid Build Coastguard Worker     ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::NotEqual);
4589*89c4ff92SAndroid Build Coastguard Worker }
4590*89c4ff92SAndroid Build Coastguard Worker 
ParseGreater(size_t subgraphIndex,size_t operatorIndex)4591*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseGreater(size_t subgraphIndex, size_t operatorIndex)
4592*89c4ff92SAndroid Build Coastguard Worker {
4593*89c4ff92SAndroid Build Coastguard Worker     ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Greater);
4594*89c4ff92SAndroid Build Coastguard Worker }
4595*89c4ff92SAndroid Build Coastguard Worker 
ParseGreaterOrEqual(size_t subgraphIndex,size_t operatorIndex)4596*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseGreaterOrEqual(size_t subgraphIndex, size_t operatorIndex)
4597*89c4ff92SAndroid Build Coastguard Worker {
4598*89c4ff92SAndroid Build Coastguard Worker     ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::GreaterOrEqual);
4599*89c4ff92SAndroid Build Coastguard Worker }
4600*89c4ff92SAndroid Build Coastguard Worker 
ParseLess(size_t subgraphIndex,size_t operatorIndex)4601*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLess(size_t subgraphIndex, size_t operatorIndex)
4602*89c4ff92SAndroid Build Coastguard Worker {
4603*89c4ff92SAndroid Build Coastguard Worker     ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::Less);
4604*89c4ff92SAndroid Build Coastguard Worker }
4605*89c4ff92SAndroid Build Coastguard Worker 
ParseLessOrEqual(size_t subgraphIndex,size_t operatorIndex)4606*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseLessOrEqual(size_t subgraphIndex, size_t operatorIndex)
4607*89c4ff92SAndroid Build Coastguard Worker {
4608*89c4ff92SAndroid Build Coastguard Worker     ParseComparison(subgraphIndex, operatorIndex, armnn::ComparisonOperation::LessOrEqual);
4609*89c4ff92SAndroid Build Coastguard Worker }
4610*89c4ff92SAndroid Build Coastguard Worker 
ParseComparison(size_t subgraphIndex,size_t operatorIndex,ComparisonOperation comparisonOperation)4611*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::ParseComparison(size_t subgraphIndex, size_t operatorIndex,
4612*89c4ff92SAndroid Build Coastguard Worker                                        ComparisonOperation comparisonOperation)
4613*89c4ff92SAndroid Build Coastguard Worker {
4614*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4615*89c4ff92SAndroid Build Coastguard Worker 
4616*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetInputs(m_Model, subgraphIndex, operatorIndex);
4617*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(inputs.size(), 2);
4618*89c4ff92SAndroid Build Coastguard Worker 
4619*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetOutputs(m_Model, subgraphIndex, operatorIndex);
4620*89c4ff92SAndroid Build Coastguard Worker     CHECK_VALID_SIZE(outputs.size(), 1);
4621*89c4ff92SAndroid Build Coastguard Worker 
4622*89c4ff92SAndroid Build Coastguard Worker     auto layerName = std::string(GetComparisonOperationAsCString(comparisonOperation)) + ":{}:{}";
4623*89c4ff92SAndroid Build Coastguard Worker     std::string layerNameFormatted = fmt::format(layerName, subgraphIndex, operatorIndex);
4624*89c4ff92SAndroid Build Coastguard Worker 
4625*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo  = InputTensorInfo(subgraphIndex, operatorIndex, 0);
4626*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo input1TensorInfo = InputTensorInfo(subgraphIndex, operatorIndex, 1);
4627*89c4ff92SAndroid Build Coastguard Worker     CheckMatchingQuantization(inputTensorInfo, input1TensorInfo, layerNameFormatted, "Input 0", "Input 1");
4628*89c4ff92SAndroid Build Coastguard Worker 
4629*89c4ff92SAndroid Build Coastguard Worker     ComparisonDescriptor desc;
4630*89c4ff92SAndroid Build Coastguard Worker     desc.m_Operation = comparisonOperation;
4631*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* layer = m_Network->AddComparisonLayer(desc, layerNameFormatted.c_str());
4632*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4633*89c4ff92SAndroid Build Coastguard Worker 
4634*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputTensorInfo = OutputTensorInfoFromInputs(subgraphIndex, operatorIndex, layer, 0, {0, 1});
4635*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
4636*89c4ff92SAndroid Build Coastguard Worker 
4637*89c4ff92SAndroid Build Coastguard Worker     auto inputTensorIndexes = AsUnsignedVector(GetInputTensorIds(m_Model, subgraphIndex, operatorIndex));
4638*89c4ff92SAndroid Build Coastguard Worker     RegisterInputSlots(subgraphIndex, operatorIndex, layer, {inputTensorIndexes[0], inputTensorIndexes[1]});
4639*89c4ff92SAndroid Build Coastguard Worker 
4640*89c4ff92SAndroid Build Coastguard Worker     auto outputTensorIndexes = AsUnsignedVector(GetOutputTensorIds(m_Model, subgraphIndex, operatorIndex));
4641*89c4ff92SAndroid Build Coastguard Worker     RegisterOutputSlots(subgraphIndex, operatorIndex, layer, {outputTensorIndexes[0]});
4642*89c4ff92SAndroid Build Coastguard Worker }
4643*89c4ff92SAndroid Build Coastguard Worker 
AddReshapeLayer(armnn::IConnectableLayer * layer,unsigned int outputSlot,std::string reshapeLayerName,armnn::TensorInfo outputShape)4644*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* TfLiteParserImpl::AddReshapeLayer(armnn::IConnectableLayer* layer,
4645*89c4ff92SAndroid Build Coastguard Worker                                                             unsigned int outputSlot,
4646*89c4ff92SAndroid Build Coastguard Worker                                                             std::string reshapeLayerName,
4647*89c4ff92SAndroid Build Coastguard Worker                                                             armnn::TensorInfo outputShape)
4648*89c4ff92SAndroid Build Coastguard Worker {
4649*89c4ff92SAndroid Build Coastguard Worker     ReshapeDescriptor desc;
4650*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetShape = outputShape.GetShape();
4651*89c4ff92SAndroid Build Coastguard Worker 
4652*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* reshapeLayer =
4653*89c4ff92SAndroid Build Coastguard Worker             m_Network->AddReshapeLayer(desc, reshapeLayerName.c_str());
4654*89c4ff92SAndroid Build Coastguard Worker 
4655*89c4ff92SAndroid Build Coastguard Worker     auto & prevOutputSlot = layer->GetOutputSlot(outputSlot);
4656*89c4ff92SAndroid Build Coastguard Worker     prevOutputSlot.Connect(reshapeLayer->GetInputSlot(0));
4657*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer->GetOutputSlot(0).SetTensorInfo(outputShape);
4658*89c4ff92SAndroid Build Coastguard Worker     return reshapeLayer;
4659*89c4ff92SAndroid Build Coastguard Worker }
4660*89c4ff92SAndroid Build Coastguard Worker 
AddFusedActivationLayer(armnn::IConnectableLayer * prevLayer,unsigned int outputSlot,tflite::ActivationFunctionType activationType)4661*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* TfLiteParserImpl::AddFusedActivationLayer(armnn::IConnectableLayer* prevLayer,
4662*89c4ff92SAndroid Build Coastguard Worker                                                                     unsigned int outputSlot,
4663*89c4ff92SAndroid Build Coastguard Worker                                                                     tflite::ActivationFunctionType activationType)
4664*89c4ff92SAndroid Build Coastguard Worker {
4665*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor activationDesc;
4666*89c4ff92SAndroid Build Coastguard Worker     std::string layerName = prevLayer->GetName();
4667*89c4ff92SAndroid Build Coastguard Worker 
4668*89c4ff92SAndroid Build Coastguard Worker     switch(activationType)
4669*89c4ff92SAndroid Build Coastguard Worker     {
4670*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_NONE:
4671*89c4ff92SAndroid Build Coastguard Worker         {
4672*89c4ff92SAndroid Build Coastguard Worker             // this is a no-op: return previous layer
4673*89c4ff92SAndroid Build Coastguard Worker             return prevLayer;
4674*89c4ff92SAndroid Build Coastguard Worker         }
4675*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_RELU:
4676*89c4ff92SAndroid Build Coastguard Worker         {
4677*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = ActivationFunction::ReLu;
4678*89c4ff92SAndroid Build Coastguard Worker             layerName += ":RELU";
4679*89c4ff92SAndroid Build Coastguard Worker             break;
4680*89c4ff92SAndroid Build Coastguard Worker         }
4681*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_RELU6:
4682*89c4ff92SAndroid Build Coastguard Worker         {
4683*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = ActivationFunction::BoundedReLu;
4684*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 6.0f;
4685*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = 0.0f;
4686*89c4ff92SAndroid Build Coastguard Worker             layerName += ":RELU6";
4687*89c4ff92SAndroid Build Coastguard Worker             break;
4688*89c4ff92SAndroid Build Coastguard Worker         }
4689*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_TANH:
4690*89c4ff92SAndroid Build Coastguard Worker         {
4691*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_Function = ActivationFunction::TanH;
4692*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_A = 1.0f;
4693*89c4ff92SAndroid Build Coastguard Worker             activationDesc.m_B = 1.0f;
4694*89c4ff92SAndroid Build Coastguard Worker             layerName += ":TANH";
4695*89c4ff92SAndroid Build Coastguard Worker             break;
4696*89c4ff92SAndroid Build Coastguard Worker         }
4697*89c4ff92SAndroid Build Coastguard Worker 
4698*89c4ff92SAndroid Build Coastguard Worker         // I only put these here as a reminder what others we could support
4699*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_RELU_N1_TO_1:
4700*89c4ff92SAndroid Build Coastguard Worker         case tflite::ActivationFunctionType_SIGN_BIT:
4701*89c4ff92SAndroid Build Coastguard Worker         default:
4702*89c4ff92SAndroid Build Coastguard Worker         {
4703*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
4704*89c4ff92SAndroid Build Coastguard Worker                 fmt::format("TfLite parser doesn't support fused activation: "
4705*89c4ff92SAndroid Build Coastguard Worker                             "{}/{} {} ",
4706*89c4ff92SAndroid Build Coastguard Worker                             activationType,
4707*89c4ff92SAndroid Build Coastguard Worker                             tflite::EnumNameActivationFunctionType(activationType),
4708*89c4ff92SAndroid Build Coastguard Worker                             CHECK_LOCATION().AsString()));
4709*89c4ff92SAndroid Build Coastguard Worker 
4710*89c4ff92SAndroid Build Coastguard Worker         }
4711*89c4ff92SAndroid Build Coastguard Worker     }
4712*89c4ff92SAndroid Build Coastguard Worker 
4713*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* activationLayer =
4714*89c4ff92SAndroid Build Coastguard Worker         m_Network->AddActivationLayer(activationDesc, layerName.c_str());
4715*89c4ff92SAndroid Build Coastguard Worker 
4716*89c4ff92SAndroid Build Coastguard Worker     auto & prevOutputSlot = prevLayer->GetOutputSlot(outputSlot);
4717*89c4ff92SAndroid Build Coastguard Worker     prevOutputSlot.Connect(activationLayer->GetInputSlot(0));
4718*89c4ff92SAndroid Build Coastguard Worker     activationLayer->GetOutputSlot(0).SetTensorInfo(prevOutputSlot.GetTensorInfo());
4719*89c4ff92SAndroid Build Coastguard Worker     return activationLayer;
4720*89c4ff92SAndroid Build Coastguard Worker }
4721*89c4ff92SAndroid Build Coastguard Worker 
AddFusedFloorLayer(armnn::IConnectableLayer * prevLayer,unsigned int outputSlot)4722*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* TfLiteParserImpl::AddFusedFloorLayer(armnn::IConnectableLayer* prevLayer,
4723*89c4ff92SAndroid Build Coastguard Worker                                                                unsigned int outputSlot)
4724*89c4ff92SAndroid Build Coastguard Worker {
4725*89c4ff92SAndroid Build Coastguard Worker 
4726*89c4ff92SAndroid Build Coastguard Worker     auto& prevOutputSlot = prevLayer->GetOutputSlot(outputSlot);
4727*89c4ff92SAndroid Build Coastguard Worker     DataType dataType = prevOutputSlot.GetTensorInfo().GetDataType();
4728*89c4ff92SAndroid Build Coastguard Worker 
4729*89c4ff92SAndroid Build Coastguard Worker     if (dataType == DataType::Signed32)
4730*89c4ff92SAndroid Build Coastguard Worker     {
4731*89c4ff92SAndroid Build Coastguard Worker         return prevLayer;
4732*89c4ff92SAndroid Build Coastguard Worker     }
4733*89c4ff92SAndroid Build Coastguard Worker 
4734*89c4ff92SAndroid Build Coastguard Worker     std::string layerName = prevLayer->GetName();
4735*89c4ff92SAndroid Build Coastguard Worker     IConnectableLayer* floorLayer = m_Network->AddFloorLayer(layerName.c_str());
4736*89c4ff92SAndroid Build Coastguard Worker 
4737*89c4ff92SAndroid Build Coastguard Worker     prevOutputSlot.Connect(floorLayer->GetInputSlot(0));
4738*89c4ff92SAndroid Build Coastguard Worker     floorLayer->GetOutputSlot(0).SetTensorInfo(prevOutputSlot.GetTensorInfo());
4739*89c4ff92SAndroid Build Coastguard Worker 
4740*89c4ff92SAndroid Build Coastguard Worker     return floorLayer;
4741*89c4ff92SAndroid Build Coastguard Worker }
4742*89c4ff92SAndroid Build Coastguard Worker 
LoadModelFromFile(const char * fileName)4743*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr TfLiteParserImpl::LoadModelFromFile(const char* fileName)
4744*89c4ff92SAndroid Build Coastguard Worker {
4745*89c4ff92SAndroid Build Coastguard Worker     if (fileName == nullptr)
4746*89c4ff92SAndroid Build Coastguard Worker     {
4747*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(fmt::format("Invalid (null) file name {}",
4748*89c4ff92SAndroid Build Coastguard Worker                                        CHECK_LOCATION().AsString()));
4749*89c4ff92SAndroid Build Coastguard Worker     }
4750*89c4ff92SAndroid Build Coastguard Worker     std::error_code errorCode;
4751*89c4ff92SAndroid Build Coastguard Worker     fs::path pathToFile(fileName);
4752*89c4ff92SAndroid Build Coastguard Worker     if (!fs::exists(pathToFile, errorCode))
4753*89c4ff92SAndroid Build Coastguard Worker     {
4754*89c4ff92SAndroid Build Coastguard Worker         //fmt::format() could not be used here (format error)
4755*89c4ff92SAndroid Build Coastguard Worker         std::stringstream msg;
4756*89c4ff92SAndroid Build Coastguard Worker         msg << "Cannot find the file (" << fileName << ") errorCode: " << errorCode
4757*89c4ff92SAndroid Build Coastguard Worker             << " " << CHECK_LOCATION().AsString();
4758*89c4ff92SAndroid Build Coastguard Worker 
4759*89c4ff92SAndroid Build Coastguard Worker         throw FileNotFoundException(msg.str());
4760*89c4ff92SAndroid Build Coastguard Worker     }
4761*89c4ff92SAndroid Build Coastguard Worker     std::ifstream file(fileName, std::ios::binary);
4762*89c4ff92SAndroid Build Coastguard Worker     std::string fileContent((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
4763*89c4ff92SAndroid Build Coastguard Worker     return LoadModelFromBinary(reinterpret_cast<const uint8_t *>(fileContent.c_str()),
4764*89c4ff92SAndroid Build Coastguard Worker                                fileContent.size());
4765*89c4ff92SAndroid Build Coastguard Worker }
4766*89c4ff92SAndroid Build Coastguard Worker 
LoadModelFromBinary(const uint8_t * binaryContent,size_t len)4767*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::ModelPtr TfLiteParserImpl::LoadModelFromBinary(const uint8_t* binaryContent, size_t len)
4768*89c4ff92SAndroid Build Coastguard Worker {
4769*89c4ff92SAndroid Build Coastguard Worker     if (binaryContent == nullptr)
4770*89c4ff92SAndroid Build Coastguard Worker      {
4771*89c4ff92SAndroid Build Coastguard Worker         throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
4772*89c4ff92SAndroid Build Coastguard Worker                                        CHECK_LOCATION().AsString()));
4773*89c4ff92SAndroid Build Coastguard Worker      }
4774*89c4ff92SAndroid Build Coastguard Worker     flatbuffers::Verifier verifier(binaryContent, len);
4775*89c4ff92SAndroid Build Coastguard Worker     if (verifier.VerifyBuffer<tflite::Model>() == false)
4776*89c4ff92SAndroid Build Coastguard Worker     {
4777*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4778*89c4ff92SAndroid Build Coastguard Worker             fmt::format("Buffer doesn't conform to the expected Tensorflow Lite "
4779*89c4ff92SAndroid Build Coastguard Worker                         "flatbuffers format. size:{} {}",
4780*89c4ff92SAndroid Build Coastguard Worker                         len,
4781*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4782*89c4ff92SAndroid Build Coastguard Worker     }
4783*89c4ff92SAndroid Build Coastguard Worker     return tflite::UnPackModel(binaryContent);
4784*89c4ff92SAndroid Build Coastguard Worker }
4785*89c4ff92SAndroid Build Coastguard Worker 
GetInputs(const ModelPtr & model,size_t subgraphIndex,size_t operatorIndex)4786*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector TfLiteParserImpl::GetInputs(const ModelPtr& model,
4787*89c4ff92SAndroid Build Coastguard Worker                                                                  size_t subgraphIndex,
4788*89c4ff92SAndroid Build Coastguard Worker                                                                  size_t operatorIndex)
4789*89c4ff92SAndroid Build Coastguard Worker {
4790*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(model, subgraphIndex, operatorIndex);
4791*89c4ff92SAndroid Build Coastguard Worker 
4792*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = model->subgraphs[subgraphIndex];
4793*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = subgraphPtr->operators[operatorIndex];
4794*89c4ff92SAndroid Build Coastguard Worker 
4795*89c4ff92SAndroid Build Coastguard Worker     size_t inputCount = operatorPtr->inputs.size();
4796*89c4ff92SAndroid Build Coastguard Worker     TensorRawPtrVector result;
4797*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < inputCount; ++i)
4798*89c4ff92SAndroid Build Coastguard Worker     {
4799*89c4ff92SAndroid Build Coastguard Worker         // If the input location is -1 then assume input is turned off.
4800*89c4ff92SAndroid Build Coastguard Worker         if (operatorPtr->inputs[i] == -1)
4801*89c4ff92SAndroid Build Coastguard Worker         {
4802*89c4ff92SAndroid Build Coastguard Worker             continue;
4803*89c4ff92SAndroid Build Coastguard Worker         }
4804*89c4ff92SAndroid Build Coastguard Worker         else
4805*89c4ff92SAndroid Build Coastguard Worker         {
4806*89c4ff92SAndroid Build Coastguard Worker             uint32_t inputId = CHECKED_NON_NEGATIVE(operatorPtr->inputs[i]);
4807*89c4ff92SAndroid Build Coastguard Worker             result.push_back(subgraphPtr->tensors[inputId].get());
4808*89c4ff92SAndroid Build Coastguard Worker         }
4809*89c4ff92SAndroid Build Coastguard Worker     }
4810*89c4ff92SAndroid Build Coastguard Worker     return result;
4811*89c4ff92SAndroid Build Coastguard Worker }
4812*89c4ff92SAndroid Build Coastguard Worker 
GetOutputs(const ModelPtr & model,size_t subgraphIndex,size_t operatorIndex)4813*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorRawPtrVector TfLiteParserImpl::GetOutputs(const ModelPtr& model,
4814*89c4ff92SAndroid Build Coastguard Worker                                                                   size_t subgraphIndex,
4815*89c4ff92SAndroid Build Coastguard Worker                                                                   size_t operatorIndex)
4816*89c4ff92SAndroid Build Coastguard Worker {
4817*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(model, subgraphIndex, operatorIndex);
4818*89c4ff92SAndroid Build Coastguard Worker 
4819*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = model->subgraphs[subgraphIndex];
4820*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = subgraphPtr->operators[operatorIndex];
4821*89c4ff92SAndroid Build Coastguard Worker 
4822*89c4ff92SAndroid Build Coastguard Worker     size_t outputCount = operatorPtr->outputs.size();
4823*89c4ff92SAndroid Build Coastguard Worker     TensorRawPtrVector result(outputCount);
4824*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < outputCount; ++i)
4825*89c4ff92SAndroid Build Coastguard Worker     {
4826*89c4ff92SAndroid Build Coastguard Worker         uint32_t outputId = CHECKED_NON_NEGATIVE(operatorPtr->outputs[i]);
4827*89c4ff92SAndroid Build Coastguard Worker         CHECK_TENSOR(model, subgraphIndex, outputId);
4828*89c4ff92SAndroid Build Coastguard Worker         result[i] = subgraphPtr->tensors[outputId].get();
4829*89c4ff92SAndroid Build Coastguard Worker     }
4830*89c4ff92SAndroid Build Coastguard Worker     return result;
4831*89c4ff92SAndroid Build Coastguard Worker }
4832*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphInputs(const ModelPtr & model,size_t subgraphIndex)4833*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector TfLiteParserImpl::GetSubgraphInputs(const ModelPtr& model,
4834*89c4ff92SAndroid Build Coastguard Worker                                                                            size_t subgraphIndex)
4835*89c4ff92SAndroid Build Coastguard Worker {
4836*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(model, subgraphIndex);
4837*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = model->subgraphs[subgraphIndex];
4838*89c4ff92SAndroid Build Coastguard Worker 
4839*89c4ff92SAndroid Build Coastguard Worker     size_t inputCount = subgraphPtr->inputs.size();
4840*89c4ff92SAndroid Build Coastguard Worker     TensorIdRawPtrVector result(inputCount);
4841*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < inputCount; ++i)
4842*89c4ff92SAndroid Build Coastguard Worker     {
4843*89c4ff92SAndroid Build Coastguard Worker         uint32_t inputId = CHECKED_NON_NEGATIVE(subgraphPtr->inputs[i]);
4844*89c4ff92SAndroid Build Coastguard Worker         CHECK_TENSOR(model, subgraphIndex, inputId);
4845*89c4ff92SAndroid Build Coastguard Worker         result[i] = std::make_pair(inputId, subgraphPtr->tensors[inputId].get());
4846*89c4ff92SAndroid Build Coastguard Worker     }
4847*89c4ff92SAndroid Build Coastguard Worker     return result;
4848*89c4ff92SAndroid Build Coastguard Worker }
4849*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphOutputs(const ModelPtr & model,size_t subgraphIndex)4850*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::TensorIdRawPtrVector TfLiteParserImpl::GetSubgraphOutputs(const ModelPtr& model,
4851*89c4ff92SAndroid Build Coastguard Worker                                                                             size_t subgraphIndex)
4852*89c4ff92SAndroid Build Coastguard Worker {
4853*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(model, subgraphIndex);
4854*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = model->subgraphs[subgraphIndex];
4855*89c4ff92SAndroid Build Coastguard Worker 
4856*89c4ff92SAndroid Build Coastguard Worker     size_t outputCount = subgraphPtr->outputs.size();
4857*89c4ff92SAndroid Build Coastguard Worker     TensorIdRawPtrVector result(outputCount);
4858*89c4ff92SAndroid Build Coastguard Worker     for (size_t i = 0; i < outputCount; ++i)
4859*89c4ff92SAndroid Build Coastguard Worker     {
4860*89c4ff92SAndroid Build Coastguard Worker         uint32_t outputId = CHECKED_NON_NEGATIVE(subgraphPtr->outputs[i]);
4861*89c4ff92SAndroid Build Coastguard Worker         result[i] = std::make_pair(outputId, subgraphPtr->tensors[outputId].get());
4862*89c4ff92SAndroid Build Coastguard Worker     }
4863*89c4ff92SAndroid Build Coastguard Worker     return result;
4864*89c4ff92SAndroid Build Coastguard Worker }
4865*89c4ff92SAndroid Build Coastguard Worker 
GetInputTensorIds(const ModelPtr & model,size_t subgraphIndex,size_t operatorIndex)4866*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t>& TfLiteParserImpl::GetInputTensorIds(const ModelPtr& model,
4867*89c4ff92SAndroid Build Coastguard Worker                                                           size_t subgraphIndex,
4868*89c4ff92SAndroid Build Coastguard Worker                                                           size_t operatorIndex)
4869*89c4ff92SAndroid Build Coastguard Worker {
4870*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(model, subgraphIndex, operatorIndex);
4871*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = model->subgraphs[subgraphIndex];
4872*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = subgraphPtr->operators[operatorIndex];
4873*89c4ff92SAndroid Build Coastguard Worker     return operatorPtr->inputs;
4874*89c4ff92SAndroid Build Coastguard Worker }
4875*89c4ff92SAndroid Build Coastguard Worker 
GetOutputTensorIds(const ModelPtr & model,size_t subgraphIndex,size_t operatorIndex)4876*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t>& TfLiteParserImpl::GetOutputTensorIds(const ModelPtr& model,
4877*89c4ff92SAndroid Build Coastguard Worker                                                            size_t subgraphIndex,
4878*89c4ff92SAndroid Build Coastguard Worker                                                            size_t operatorIndex)
4879*89c4ff92SAndroid Build Coastguard Worker {
4880*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(model, subgraphIndex, operatorIndex);
4881*89c4ff92SAndroid Build Coastguard Worker     const auto& subgraphPtr = model->subgraphs[subgraphIndex];
4882*89c4ff92SAndroid Build Coastguard Worker     const auto& operatorPtr = subgraphPtr->operators[operatorIndex];
4883*89c4ff92SAndroid Build Coastguard Worker     return operatorPtr->outputs;
4884*89c4ff92SAndroid Build Coastguard Worker }
4885*89c4ff92SAndroid Build Coastguard Worker 
RegisterInputSlots(size_t subgraphIndex,size_t operatorIndex,IConnectableLayer * layer,const std::vector<unsigned int> & tensorIndexes,unsigned int startingSlotIndex)4886*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::RegisterInputSlots(size_t subgraphIndex,
4887*89c4ff92SAndroid Build Coastguard Worker                                           size_t operatorIndex,
4888*89c4ff92SAndroid Build Coastguard Worker                                           IConnectableLayer* layer,
4889*89c4ff92SAndroid Build Coastguard Worker                                           const std::vector<unsigned int>& tensorIndexes,
4890*89c4ff92SAndroid Build Coastguard Worker                                           unsigned int startingSlotIndex)
4891*89c4ff92SAndroid Build Coastguard Worker {
4892*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4893*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4894*89c4ff92SAndroid Build Coastguard Worker 
4895*89c4ff92SAndroid Build Coastguard Worker     if (tensorIndexes.size() + startingSlotIndex != layer->GetNumInputSlots())
4896*89c4ff92SAndroid Build Coastguard Worker     {
4897*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4898*89c4ff92SAndroid Build Coastguard Worker             fmt::format("The number of tensor inputs ({}) does not match the number expected ({})"
4899*89c4ff92SAndroid Build Coastguard Worker                         " for subgraph:{} operator index:{} {}",
4900*89c4ff92SAndroid Build Coastguard Worker                         tensorIndexes.size(),
4901*89c4ff92SAndroid Build Coastguard Worker                         layer->GetNumInputSlots(),
4902*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
4903*89c4ff92SAndroid Build Coastguard Worker                         operatorIndex,
4904*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4905*89c4ff92SAndroid Build Coastguard Worker     }
4906*89c4ff92SAndroid Build Coastguard Worker 
4907*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int index = 0; index < tensorIndexes.size() ; ++index)
4908*89c4ff92SAndroid Build Coastguard Worker     {
4909*89c4ff92SAndroid Build Coastguard Worker         unsigned int tensorIndex = tensorIndexes[index];
4910*89c4ff92SAndroid Build Coastguard Worker         armnn::IInputSlot* slot = &(layer->GetInputSlot(startingSlotIndex + index));
4911*89c4ff92SAndroid Build Coastguard Worker         RegisterConsumerOfTensor(subgraphIndex, tensorIndex, slot);
4912*89c4ff92SAndroid Build Coastguard Worker     }
4913*89c4ff92SAndroid Build Coastguard Worker }
4914*89c4ff92SAndroid Build Coastguard Worker 
RegisterOutputSlots(size_t subgraphIndex,size_t operatorIndex,IConnectableLayer * layer,const std::vector<unsigned int> & tensorIndexes)4915*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::RegisterOutputSlots(size_t subgraphIndex,
4916*89c4ff92SAndroid Build Coastguard Worker                                            size_t operatorIndex,
4917*89c4ff92SAndroid Build Coastguard Worker                                            IConnectableLayer* layer,
4918*89c4ff92SAndroid Build Coastguard Worker                                            const std::vector<unsigned int>& tensorIndexes)
4919*89c4ff92SAndroid Build Coastguard Worker {
4920*89c4ff92SAndroid Build Coastguard Worker     CHECK_MODEL(m_Model, subgraphIndex, operatorIndex);
4921*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
4922*89c4ff92SAndroid Build Coastguard Worker     if (tensorIndexes.size() != layer->GetNumOutputSlots())
4923*89c4ff92SAndroid Build Coastguard Worker     {
4924*89c4ff92SAndroid Build Coastguard Worker         throw ParseException(
4925*89c4ff92SAndroid Build Coastguard Worker             fmt::format("The number of tensor outputs ({}) does not match the number expected ({})"
4926*89c4ff92SAndroid Build Coastguard Worker                         " for subgraph:{} operator index:{} {}",
4927*89c4ff92SAndroid Build Coastguard Worker                         tensorIndexes.size(),
4928*89c4ff92SAndroid Build Coastguard Worker                         layer->GetNumOutputSlots(),
4929*89c4ff92SAndroid Build Coastguard Worker                         subgraphIndex,
4930*89c4ff92SAndroid Build Coastguard Worker                         operatorIndex,
4931*89c4ff92SAndroid Build Coastguard Worker                         CHECK_LOCATION().AsString()));
4932*89c4ff92SAndroid Build Coastguard Worker     }
4933*89c4ff92SAndroid Build Coastguard Worker 
4934*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
4935*89c4ff92SAndroid Build Coastguard Worker     {
4936*89c4ff92SAndroid Build Coastguard Worker         unsigned int tensorIndex = tensorIndexes[slotIndex];
4937*89c4ff92SAndroid Build Coastguard Worker         armnn::IOutputSlot* slot = &(layer->GetOutputSlot(slotIndex));
4938*89c4ff92SAndroid Build Coastguard Worker         RegisterProducerOfTensor(subgraphIndex, tensorIndex, slot);
4939*89c4ff92SAndroid Build Coastguard Worker     }
4940*89c4ff92SAndroid Build Coastguard Worker }
4941*89c4ff92SAndroid Build Coastguard Worker 
SetupInputLayerTensorInfos(size_t subgraphIndex)4942*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::SetupInputLayerTensorInfos(size_t subgraphIndex)
4943*89c4ff92SAndroid Build Coastguard Worker {
4944*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphIndex);
4945*89c4ff92SAndroid Build Coastguard Worker 
4946*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetSubgraphInputs(m_Model, subgraphIndex);
4947*89c4ff92SAndroid Build Coastguard Worker     for (auto const& tensorIdAndPtr : inputs)
4948*89c4ff92SAndroid Build Coastguard Worker     {
4949*89c4ff92SAndroid Build Coastguard Worker         auto tensorInfo = ToTensorInfo(tensorIdAndPtr.second);
4950*89c4ff92SAndroid Build Coastguard Worker         m_TensorInfos.insert({tensorIdAndPtr.first, tensorInfo});
4951*89c4ff92SAndroid Build Coastguard Worker     }
4952*89c4ff92SAndroid Build Coastguard Worker }
4953*89c4ff92SAndroid Build Coastguard Worker 
SetupInputLayers(size_t subgraphIndex)4954*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::SetupInputLayers(size_t subgraphIndex)
4955*89c4ff92SAndroid Build Coastguard Worker {
4956*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphIndex);
4957*89c4ff92SAndroid Build Coastguard Worker 
4958*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetSubgraphInputs(m_Model, subgraphIndex);
4959*89c4ff92SAndroid Build Coastguard Worker     for (auto const& tensorIdAndPtr : inputs)
4960*89c4ff92SAndroid Build Coastguard Worker     {
4961*89c4ff92SAndroid Build Coastguard Worker         auto bindingId = GenerateLayerBindingId(subgraphIndex, tensorIdAndPtr.first);
4962*89c4ff92SAndroid Build Coastguard Worker         IConnectableLayer* layer =
4963*89c4ff92SAndroid Build Coastguard Worker             m_Network->AddInputLayer(bindingId, tensorIdAndPtr.second->name.c_str());
4964*89c4ff92SAndroid Build Coastguard Worker 
4965*89c4ff92SAndroid Build Coastguard Worker         auto tensorInfo = ToTensorInfo(tensorIdAndPtr.second);
4966*89c4ff92SAndroid Build Coastguard Worker         layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
4967*89c4ff92SAndroid Build Coastguard Worker 
4968*89c4ff92SAndroid Build Coastguard Worker         RegisterOutputSlots(subgraphIndex,
4969*89c4ff92SAndroid Build Coastguard Worker                             VIRTUAL_OPERATOR_ID,
4970*89c4ff92SAndroid Build Coastguard Worker                             layer,
4971*89c4ff92SAndroid Build Coastguard Worker                             { static_cast<uint32_t>(tensorIdAndPtr.first) });
4972*89c4ff92SAndroid Build Coastguard Worker     }
4973*89c4ff92SAndroid Build Coastguard Worker }
4974*89c4ff92SAndroid Build Coastguard Worker 
SetupOutputLayers(size_t subgraphIndex)4975*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::SetupOutputLayers(size_t subgraphIndex)
4976*89c4ff92SAndroid Build Coastguard Worker {
4977*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphIndex);
4978*89c4ff92SAndroid Build Coastguard Worker 
4979*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetSubgraphOutputs(m_Model, subgraphIndex);
4980*89c4ff92SAndroid Build Coastguard Worker     for (auto const& tensorIdAndPtr : outputs)
4981*89c4ff92SAndroid Build Coastguard Worker     {
4982*89c4ff92SAndroid Build Coastguard Worker         auto bindingId = GenerateLayerBindingId(subgraphIndex, tensorIdAndPtr.first);
4983*89c4ff92SAndroid Build Coastguard Worker         IConnectableLayer* layer =
4984*89c4ff92SAndroid Build Coastguard Worker             m_Network->AddOutputLayer(bindingId, tensorIdAndPtr.second->name.c_str());
4985*89c4ff92SAndroid Build Coastguard Worker 
4986*89c4ff92SAndroid Build Coastguard Worker         RegisterInputSlots(subgraphIndex,
4987*89c4ff92SAndroid Build Coastguard Worker                            VIRTUAL_OPERATOR_ID,
4988*89c4ff92SAndroid Build Coastguard Worker                            layer,
4989*89c4ff92SAndroid Build Coastguard Worker                            { static_cast<uint32_t>(tensorIdAndPtr.first) });
4990*89c4ff92SAndroid Build Coastguard Worker     }
4991*89c4ff92SAndroid Build Coastguard Worker }
4992*89c4ff92SAndroid Build Coastguard Worker 
SetupConstantLayerTensorInfos(size_t subgraph)4993*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::SetupConstantLayerTensorInfos(size_t subgraph)
4994*89c4ff92SAndroid Build Coastguard Worker {
4995*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraph);
4996*89c4ff92SAndroid Build Coastguard Worker 
4997*89c4ff92SAndroid Build Coastguard Worker     const auto & subgraphPtr = m_Model->subgraphs[subgraph];
4998*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int subgraphIndex = 0; subgraphIndex < m_SubgraphConnections.size(); ++subgraphIndex)
4999*89c4ff92SAndroid Build Coastguard Worker     {
5000*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int tensorIndex = 0; tensorIndex < m_SubgraphConnections[subgraphIndex].size(); ++tensorIndex)
5001*89c4ff92SAndroid Build Coastguard Worker         {
5002*89c4ff92SAndroid Build Coastguard Worker             if (m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot == nullptr &&
5003*89c4ff92SAndroid Build Coastguard Worker                 m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots.size() > 0)
5004*89c4ff92SAndroid Build Coastguard Worker             {
5005*89c4ff92SAndroid Build Coastguard Worker                 TensorRawPtr tensorPtr = subgraphPtr->tensors[tensorIndex].get();
5006*89c4ff92SAndroid Build Coastguard Worker 
5007*89c4ff92SAndroid Build Coastguard Worker                 armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
5008*89c4ff92SAndroid Build Coastguard Worker 
5009*89c4ff92SAndroid Build Coastguard Worker                 m_TensorInfos.insert({tensorIndex, tensorInfo});
5010*89c4ff92SAndroid Build Coastguard Worker             }
5011*89c4ff92SAndroid Build Coastguard Worker         }
5012*89c4ff92SAndroid Build Coastguard Worker     }
5013*89c4ff92SAndroid Build Coastguard Worker }
5014*89c4ff92SAndroid Build Coastguard Worker 
SetupConstantLayers(size_t subgraph)5015*89c4ff92SAndroid Build Coastguard Worker void TfLiteParserImpl::SetupConstantLayers(size_t subgraph)
5016*89c4ff92SAndroid Build Coastguard Worker {
5017*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraph);
5018*89c4ff92SAndroid Build Coastguard Worker 
5019*89c4ff92SAndroid Build Coastguard Worker     const auto & subgraphPtr = m_Model->subgraphs[subgraph];
5020*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int subgraphIndex = 0; subgraphIndex < m_SubgraphConnections.size(); ++subgraphIndex)
5021*89c4ff92SAndroid Build Coastguard Worker     {
5022*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int tensorIndex = 0; tensorIndex < m_SubgraphConnections[subgraphIndex].size(); ++tensorIndex)
5023*89c4ff92SAndroid Build Coastguard Worker         {
5024*89c4ff92SAndroid Build Coastguard Worker             if (m_SubgraphConnections[subgraphIndex][tensorIndex].outputSlot == nullptr &&
5025*89c4ff92SAndroid Build Coastguard Worker                 m_SubgraphConnections[subgraphIndex][tensorIndex].inputSlots.size() > 0)
5026*89c4ff92SAndroid Build Coastguard Worker             {
5027*89c4ff92SAndroid Build Coastguard Worker                 TensorRawPtr tensorPtr = subgraphPtr->tensors[tensorIndex].get();
5028*89c4ff92SAndroid Build Coastguard Worker 
5029*89c4ff92SAndroid Build Coastguard Worker                 if (IsConstTensor(tensorPtr))
5030*89c4ff92SAndroid Build Coastguard Worker                 {
5031*89c4ff92SAndroid Build Coastguard Worker                     armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
5032*89c4ff92SAndroid Build Coastguard Worker                     armnn::DataType dataType = tensorInfo.GetDataType();
5033*89c4ff92SAndroid Build Coastguard Worker 
5034*89c4ff92SAndroid Build Coastguard Worker                     if (std::find(m_ConstantsToDequantize.begin(), m_ConstantsToDequantize.end(), tensorPtr->buffer)
5035*89c4ff92SAndroid Build Coastguard Worker                         != m_ConstantsToDequantize.end())
5036*89c4ff92SAndroid Build Coastguard Worker                     {
5037*89c4ff92SAndroid Build Coastguard Worker                         dataType = DataType::Float32;
5038*89c4ff92SAndroid Build Coastguard Worker                     }
5039*89c4ff92SAndroid Build Coastguard Worker                     auto tensorAndData = CreateConstTensorNonPermuted(tensorPtr, tensorInfo, dataType);
5040*89c4ff92SAndroid Build Coastguard Worker 
5041*89c4ff92SAndroid Build Coastguard Worker                     std::string layerName = fmt::format("Constant:{}", tensorPtr->name);
5042*89c4ff92SAndroid Build Coastguard Worker                     IConnectableLayer *layer = m_Network->AddConstantLayer(tensorAndData.first, layerName.c_str());
5043*89c4ff92SAndroid Build Coastguard Worker 
5044*89c4ff92SAndroid Build Coastguard Worker                     layer->GetOutputSlot(0).SetTensorInfo(tensorAndData.first.GetInfo());
5045*89c4ff92SAndroid Build Coastguard Worker                     RegisterOutputSlots(subgraphIndex,
5046*89c4ff92SAndroid Build Coastguard Worker                                         VIRTUAL_OPERATOR_ID,
5047*89c4ff92SAndroid Build Coastguard Worker                                         layer,
5048*89c4ff92SAndroid Build Coastguard Worker                                         { tensorIndex });
5049*89c4ff92SAndroid Build Coastguard Worker                 }
5050*89c4ff92SAndroid Build Coastguard Worker                 else if (ShouldConstantTensorBeCreated(tensorIndex))
5051*89c4ff92SAndroid Build Coastguard Worker                 {
5052*89c4ff92SAndroid Build Coastguard Worker                     armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
5053*89c4ff92SAndroid Build Coastguard Worker                     armnn::DataType dataType = tensorInfo.GetDataType();
5054*89c4ff92SAndroid Build Coastguard Worker 
5055*89c4ff92SAndroid Build Coastguard Worker                     if (std::find(m_ConstantsToDequantize.begin(), m_ConstantsToDequantize.end(), tensorPtr->buffer)
5056*89c4ff92SAndroid Build Coastguard Worker                         != m_ConstantsToDequantize.end())
5057*89c4ff92SAndroid Build Coastguard Worker                     {
5058*89c4ff92SAndroid Build Coastguard Worker                         dataType = DataType::Float32;
5059*89c4ff92SAndroid Build Coastguard Worker                     }
5060*89c4ff92SAndroid Build Coastguard Worker                     // Make sure isConstant flag is set.
5061*89c4ff92SAndroid Build Coastguard Worker                     tensorInfo.SetConstant();
5062*89c4ff92SAndroid Build Coastguard Worker                     tensorInfo.SetDataType(dataType);
5063*89c4ff92SAndroid Build Coastguard Worker 
5064*89c4ff92SAndroid Build Coastguard Worker                     auto tensorAndData = ConstTensor(tensorInfo, std::vector<uint8_t>(tensorInfo.GetNumBytes()));
5065*89c4ff92SAndroid Build Coastguard Worker 
5066*89c4ff92SAndroid Build Coastguard Worker                     std::string layerName = fmt::format("Constant:{}", tensorPtr->name);
5067*89c4ff92SAndroid Build Coastguard Worker                     IConnectableLayer* layer = m_Network->AddConstantLayer(tensorAndData, layerName.c_str());
5068*89c4ff92SAndroid Build Coastguard Worker 
5069*89c4ff92SAndroid Build Coastguard Worker                     layer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
5070*89c4ff92SAndroid Build Coastguard Worker                     RegisterOutputSlots(subgraphIndex,
5071*89c4ff92SAndroid Build Coastguard Worker                                         VIRTUAL_OPERATOR_ID,
5072*89c4ff92SAndroid Build Coastguard Worker                                         layer,
5073*89c4ff92SAndroid Build Coastguard Worker                                         {tensorIndex});
5074*89c4ff92SAndroid Build Coastguard Worker                 }
5075*89c4ff92SAndroid Build Coastguard Worker                 else
5076*89c4ff92SAndroid Build Coastguard Worker                 {
5077*89c4ff92SAndroid Build Coastguard Worker                     throw ParseException(
5078*89c4ff92SAndroid Build Coastguard Worker                             fmt::format("Invalid Tensor: Tensor should be constant. {}",
5079*89c4ff92SAndroid Build Coastguard Worker                                         CHECK_LOCATION().AsString()));
5080*89c4ff92SAndroid Build Coastguard Worker                 }
5081*89c4ff92SAndroid Build Coastguard Worker             }
5082*89c4ff92SAndroid Build Coastguard Worker         }
5083*89c4ff92SAndroid Build Coastguard Worker     }
5084*89c4ff92SAndroid Build Coastguard Worker }
5085*89c4ff92SAndroid Build Coastguard Worker 
5086*89c4ff92SAndroid Build Coastguard Worker // example usage: BufferRawPtr bufferPtr = GetBuffer(m_Model, inputs[0]->buffer);
GetBuffer(const ModelPtr & model,size_t bufferIndex)5087*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::BufferRawPtr TfLiteParserImpl::GetBuffer(const ModelPtr& model, size_t bufferIndex)
5088*89c4ff92SAndroid Build Coastguard Worker {
5089*89c4ff92SAndroid Build Coastguard Worker     CHECK_BUFFER(model, bufferIndex);
5090*89c4ff92SAndroid Build Coastguard Worker     return model->buffers[bufferIndex].get();
5091*89c4ff92SAndroid Build Coastguard Worker }
5092*89c4ff92SAndroid Build Coastguard Worker 
5093*89c4ff92SAndroid Build Coastguard Worker template<typename T>
5094*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,TfLiteParserImpl::TensorRawPtr tensorPtr,armnn::TensorInfo & tensorInfo,armnn::Optional<armnn::PermutationVector &> permutationVector)5095*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::CreateConstTensorAndStoreData(TfLiteParserImpl::BufferRawPtr bufferPtr,
5096*89c4ff92SAndroid Build Coastguard Worker                                             TfLiteParserImpl::TensorRawPtr tensorPtr,
5097*89c4ff92SAndroid Build Coastguard Worker                                             armnn::TensorInfo& tensorInfo,
5098*89c4ff92SAndroid Build Coastguard Worker                                             armnn::Optional<armnn::PermutationVector&> permutationVector)
5099*89c4ff92SAndroid Build Coastguard Worker {
5100*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
5101*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
5102*89c4ff92SAndroid Build Coastguard Worker 
5103*89c4ff92SAndroid Build Coastguard Worker     auto constData = CreateConstTensorImpl<T>(bufferPtr,
5104*89c4ff92SAndroid Build Coastguard Worker                                               tensorPtr,
5105*89c4ff92SAndroid Build Coastguard Worker                                               tensorInfo,
5106*89c4ff92SAndroid Build Coastguard Worker                                               permutationVector);
5107*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserImpl::SupportedDataStorage storage(std::move(constData.second));
5108*89c4ff92SAndroid Build Coastguard Worker     return std::make_pair(constData.first, std::move(storage));
5109*89c4ff92SAndroid Build Coastguard Worker }
5110*89c4ff92SAndroid Build Coastguard Worker 
ShouldConstantTensorBeCreated(unsigned int tensorIndex)5111*89c4ff92SAndroid Build Coastguard Worker bool TfLiteParserImpl::ShouldConstantTensorBeCreated(unsigned int tensorIndex)
5112*89c4ff92SAndroid Build Coastguard Worker {
5113*89c4ff92SAndroid Build Coastguard Worker     // If the TensorIndex appears in the list of ConstantsToBeCreated then return true
5114*89c4ff92SAndroid Build Coastguard Worker     return (std::find(m_ConstantsToBeCreated.begin(), m_ConstantsToBeCreated.end(), tensorIndex)
5115*89c4ff92SAndroid Build Coastguard Worker             != m_ConstantsToBeCreated.end());
5116*89c4ff92SAndroid Build Coastguard Worker }
5117*89c4ff92SAndroid Build Coastguard Worker 
IsConstTensor(TensorRawPtr tensorPtr)5118*89c4ff92SAndroid Build Coastguard Worker bool TfLiteParserImpl::IsConstTensor(TensorRawPtr tensorPtr)
5119*89c4ff92SAndroid Build Coastguard Worker {
5120*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR_PTR(tensorPtr);
5121*89c4ff92SAndroid Build Coastguard Worker     bool isConst = true;
5122*89c4ff92SAndroid Build Coastguard Worker 
5123*89c4ff92SAndroid Build Coastguard Worker     auto buffer = GetBuffer(m_Model, tensorPtr->buffer);
5124*89c4ff92SAndroid Build Coastguard Worker     if (buffer->data.size() == 0)
5125*89c4ff92SAndroid Build Coastguard Worker     {
5126*89c4ff92SAndroid Build Coastguard Worker         isConst = false;
5127*89c4ff92SAndroid Build Coastguard Worker     }
5128*89c4ff92SAndroid Build Coastguard Worker 
5129*89c4ff92SAndroid Build Coastguard Worker     return isConst;
5130*89c4ff92SAndroid Build Coastguard Worker }
5131*89c4ff92SAndroid Build Coastguard Worker 
5132*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, TfLiteParserImpl::SupportedDataStorage>
CreateConstTensorPermuted(TensorRawPtr tensorPtr,armnn::TensorInfo & tensorInfo,armnn::Optional<armnn::PermutationVector &> permutationVector)5133*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::CreateConstTensorPermuted(TensorRawPtr tensorPtr,
5134*89c4ff92SAndroid Build Coastguard Worker                                             armnn::TensorInfo& tensorInfo,
5135*89c4ff92SAndroid Build Coastguard Worker                                             armnn::Optional<armnn::PermutationVector&> permutationVector)
5136*89c4ff92SAndroid Build Coastguard Worker {
5137*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR_PTR(tensorPtr);
5138*89c4ff92SAndroid Build Coastguard Worker     auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer);
5139*89c4ff92SAndroid Build Coastguard Worker     CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer);
5140*89c4ff92SAndroid Build Coastguard Worker 
5141*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
5142*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
5143*89c4ff92SAndroid Build Coastguard Worker 
5144*89c4ff92SAndroid Build Coastguard Worker     switch (tensorInfo.GetDataType())
5145*89c4ff92SAndroid Build Coastguard Worker     {
5146*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float32:
5147*89c4ff92SAndroid Build Coastguard Worker             return CreateConstTensorAndStoreData<float>(bufferPtr,
5148*89c4ff92SAndroid Build Coastguard Worker                                                         tensorPtr,
5149*89c4ff92SAndroid Build Coastguard Worker                                                         tensorInfo,
5150*89c4ff92SAndroid Build Coastguard Worker                                                         permutationVector);
5151*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmU8:
5152*89c4ff92SAndroid Build Coastguard Worker             return CreateConstTensorAndStoreData<uint8_t>(bufferPtr,
5153*89c4ff92SAndroid Build Coastguard Worker                                                           tensorPtr,
5154*89c4ff92SAndroid Build Coastguard Worker                                                           tensorInfo,
5155*89c4ff92SAndroid Build Coastguard Worker                                                           permutationVector);
5156*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS8:
5157*89c4ff92SAndroid Build Coastguard Worker             return CreateConstTensorAndStoreData<int8_t>(bufferPtr,
5158*89c4ff92SAndroid Build Coastguard Worker                                                          tensorPtr,
5159*89c4ff92SAndroid Build Coastguard Worker                                                          tensorInfo,
5160*89c4ff92SAndroid Build Coastguard Worker                                                          permutationVector);
5161*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmS8:
5162*89c4ff92SAndroid Build Coastguard Worker             return CreateConstTensorAndStoreData<int8_t>(bufferPtr,
5163*89c4ff92SAndroid Build Coastguard Worker                                                          tensorPtr,
5164*89c4ff92SAndroid Build Coastguard Worker                                                          tensorInfo,
5165*89c4ff92SAndroid Build Coastguard Worker                                                          permutationVector);
5166*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Signed32:
5167*89c4ff92SAndroid Build Coastguard Worker             return CreateConstTensorAndStoreData<int32_t>(bufferPtr,
5168*89c4ff92SAndroid Build Coastguard Worker                                                           tensorPtr,
5169*89c4ff92SAndroid Build Coastguard Worker                                                           tensorInfo,
5170*89c4ff92SAndroid Build Coastguard Worker                                                           permutationVector);
5171*89c4ff92SAndroid Build Coastguard Worker         default:
5172*89c4ff92SAndroid Build Coastguard Worker         {
5173*89c4ff92SAndroid Build Coastguard Worker             std::stringstream errString;
5174*89c4ff92SAndroid Build Coastguard Worker             errString << "Unexpected datatype when creating const tensor: "
5175*89c4ff92SAndroid Build Coastguard Worker                         << armnn::GetDataTypeName(tensorInfo.GetDataType())
5176*89c4ff92SAndroid Build Coastguard Worker                         << " shape:" << tensorInfo.GetShape()
5177*89c4ff92SAndroid Build Coastguard Worker                         << CHECK_LOCATION().AsString();
5178*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(errString.str());
5179*89c4ff92SAndroid Build Coastguard Worker         }
5180*89c4ff92SAndroid Build Coastguard Worker     }
5181*89c4ff92SAndroid Build Coastguard Worker }
5182*89c4ff92SAndroid Build Coastguard Worker 
CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,armnn::TensorInfo & tensorInfo)5183*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
5184*89c4ff92SAndroid Build Coastguard Worker                                                                   armnn::TensorInfo& tensorInfo)
5185*89c4ff92SAndroid Build Coastguard Worker {
5186*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR_PTR(tensorPtr);
5187*89c4ff92SAndroid Build Coastguard Worker     auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer);
5188*89c4ff92SAndroid Build Coastguard Worker     CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer);
5189*89c4ff92SAndroid Build Coastguard Worker 
5190*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
5191*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
5192*89c4ff92SAndroid Build Coastguard Worker 
5193*89c4ff92SAndroid Build Coastguard Worker     return ConstTensor(tensorInfo, bufferPtr->data.data());
5194*89c4ff92SAndroid Build Coastguard Worker }
5195*89c4ff92SAndroid Build Coastguard Worker 
5196*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor, std::unique_ptr<float[]>>
CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,armnn::TensorInfo & tensorInfo,armnn::DataType inputDataType)5197*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::CreateConstTensorNonPermuted(TensorRawPtr tensorPtr,
5198*89c4ff92SAndroid Build Coastguard Worker                                                armnn::TensorInfo& tensorInfo,
5199*89c4ff92SAndroid Build Coastguard Worker                                                armnn::DataType inputDataType)
5200*89c4ff92SAndroid Build Coastguard Worker {
5201*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR_PTR(tensorPtr);
5202*89c4ff92SAndroid Build Coastguard Worker     auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer);
5203*89c4ff92SAndroid Build Coastguard Worker     CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer);
5204*89c4ff92SAndroid Build Coastguard Worker 
5205*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
5206*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
5207*89c4ff92SAndroid Build Coastguard Worker 
5208*89c4ff92SAndroid Build Coastguard Worker     if (inputDataType == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
5209*89c4ff92SAndroid Build Coastguard Worker     {
5210*89c4ff92SAndroid Build Coastguard Worker         try
5211*89c4ff92SAndroid Build Coastguard Worker         {
5212*89c4ff92SAndroid Build Coastguard Worker             TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
5213*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<float[]> data = armnnUtils::ToFloatArray(bufferPtr->data, tensorInfo);
5214*89c4ff92SAndroid Build Coastguard Worker             return std::make_pair(ConstTensor(constTensorInfo, data.get()), std::move(data));
5215*89c4ff92SAndroid Build Coastguard Worker         }
5216*89c4ff92SAndroid Build Coastguard Worker         catch (InvalidArgumentException&)
5217*89c4ff92SAndroid Build Coastguard Worker         {
5218*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
5219*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("Unsupported input/weights combination:  Input {} not supported with Weights {}",
5220*89c4ff92SAndroid Build Coastguard Worker                                 GetDataTypeName(DataType::Float32),
5221*89c4ff92SAndroid Build Coastguard Worker                                 GetDataTypeName(tensorInfo.GetDataType()),
5222*89c4ff92SAndroid Build Coastguard Worker                                 CHECK_LOCATION().AsString()));
5223*89c4ff92SAndroid Build Coastguard Worker         }
5224*89c4ff92SAndroid Build Coastguard Worker     }
5225*89c4ff92SAndroid Build Coastguard Worker     else
5226*89c4ff92SAndroid Build Coastguard Worker     {
5227*89c4ff92SAndroid Build Coastguard Worker         return std::make_pair(ConstTensor(tensorInfo, bufferPtr->data.data()), std::unique_ptr<float[]>());
5228*89c4ff92SAndroid Build Coastguard Worker     }
5229*89c4ff92SAndroid Build Coastguard Worker }
5230*89c4ff92SAndroid Build Coastguard Worker 
5231*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::ConstTensor*, std::unique_ptr<float[]>>
CreateConstTensorPtr(TensorRawPtr tensorPtr,armnn::TensorInfo & inputTensorInfo)5232*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::CreateConstTensorPtr(TensorRawPtr tensorPtr, armnn::TensorInfo& inputTensorInfo)
5233*89c4ff92SAndroid Build Coastguard Worker {
5234*89c4ff92SAndroid Build Coastguard Worker     CHECK_TENSOR_PTR(tensorPtr);
5235*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo tensorInfo = ToTensorInfo(tensorPtr);
5236*89c4ff92SAndroid Build Coastguard Worker     auto bufferPtr = GetBuffer(m_Model, tensorPtr->buffer);
5237*89c4ff92SAndroid Build Coastguard Worker     CHECK_BUFFER_SIZE(bufferPtr, tensorInfo, tensorPtr->buffer);
5238*89c4ff92SAndroid Build Coastguard Worker 
5239*89c4ff92SAndroid Build Coastguard Worker     // Make sure isConstant flag is set.
5240*89c4ff92SAndroid Build Coastguard Worker     tensorInfo.SetConstant();
5241*89c4ff92SAndroid Build Coastguard Worker 
5242*89c4ff92SAndroid Build Coastguard Worker     if (inputTensorInfo.GetDataType() == DataType::Float32 && tensorInfo.GetDataType() != DataType::Float32)
5243*89c4ff92SAndroid Build Coastguard Worker     {
5244*89c4ff92SAndroid Build Coastguard Worker         try
5245*89c4ff92SAndroid Build Coastguard Worker         {
5246*89c4ff92SAndroid Build Coastguard Worker             TensorInfo constTensorInfo(tensorInfo.GetShape(), DataType::Float32, 0.0f, 0, true);
5247*89c4ff92SAndroid Build Coastguard Worker             std::unique_ptr<float[]> data = armnnUtils::ToFloatArray(bufferPtr->data, tensorInfo);
5248*89c4ff92SAndroid Build Coastguard Worker             return std::make_pair(new ConstTensor(constTensorInfo, data.get()), std::move(data));
5249*89c4ff92SAndroid Build Coastguard Worker         }
5250*89c4ff92SAndroid Build Coastguard Worker         catch (InvalidArgumentException&)
5251*89c4ff92SAndroid Build Coastguard Worker         {
5252*89c4ff92SAndroid Build Coastguard Worker             throw ParseException(
5253*89c4ff92SAndroid Build Coastguard Worker                     fmt::format("Unsupported input/weights combination:  Input {} not supported with Weights {}",
5254*89c4ff92SAndroid Build Coastguard Worker                                 GetDataTypeName(DataType::Float32),
5255*89c4ff92SAndroid Build Coastguard Worker                                 GetDataTypeName(tensorInfo.GetDataType()),
5256*89c4ff92SAndroid Build Coastguard Worker                                 CHECK_LOCATION().AsString()));
5257*89c4ff92SAndroid Build Coastguard Worker         }
5258*89c4ff92SAndroid Build Coastguard Worker     }
5259*89c4ff92SAndroid Build Coastguard Worker     else
5260*89c4ff92SAndroid Build Coastguard Worker     {
5261*89c4ff92SAndroid Build Coastguard Worker         return std::make_pair(new ConstTensor(tensorInfo, bufferPtr->data.data()), std::unique_ptr<float[]>());
5262*89c4ff92SAndroid Build Coastguard Worker     }
5263*89c4ff92SAndroid Build Coastguard Worker }
5264*89c4ff92SAndroid Build Coastguard Worker 
GetNetworkInputBindingInfo(size_t subgraphId,const std::string & name) const5265*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo TfLiteParserImpl::GetNetworkInputBindingInfo(size_t subgraphId,
5266*89c4ff92SAndroid Build Coastguard Worker                                                               const std::string& name) const
5267*89c4ff92SAndroid Build Coastguard Worker {
5268*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphId);
5269*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetSubgraphInputs(m_Model, subgraphId);
5270*89c4ff92SAndroid Build Coastguard Worker     for (auto const& input : inputs)
5271*89c4ff92SAndroid Build Coastguard Worker     {
5272*89c4ff92SAndroid Build Coastguard Worker         if (input.second->name == name)
5273*89c4ff92SAndroid Build Coastguard Worker         {
5274*89c4ff92SAndroid Build Coastguard Worker             auto bindingId = GenerateLayerBindingId(subgraphId, input.first);
5275*89c4ff92SAndroid Build Coastguard Worker             auto inputTensorInfo = ToTensorInfo(input.second);
5276*89c4ff92SAndroid Build Coastguard Worker             // Input tensors are always treated as constant tensors during network execution.
5277*89c4ff92SAndroid Build Coastguard Worker             inputTensorInfo.SetConstant(true);
5278*89c4ff92SAndroid Build Coastguard Worker             return std::make_pair(bindingId, inputTensorInfo);
5279*89c4ff92SAndroid Build Coastguard Worker         }
5280*89c4ff92SAndroid Build Coastguard Worker     }
5281*89c4ff92SAndroid Build Coastguard Worker 
5282*89c4ff92SAndroid Build Coastguard Worker     std::stringstream bindings;
5283*89c4ff92SAndroid Build Coastguard Worker     for (auto const& input : inputs)
5284*89c4ff92SAndroid Build Coastguard Worker     {
5285*89c4ff92SAndroid Build Coastguard Worker         bindings << "'" << input.second->name << "' ";
5286*89c4ff92SAndroid Build Coastguard Worker     }
5287*89c4ff92SAndroid Build Coastguard Worker 
5288*89c4ff92SAndroid Build Coastguard Worker     throw ParseException(
5289*89c4ff92SAndroid Build Coastguard Worker         fmt::format("No input binding found for subgraph:{} and name:{}. "
5290*89c4ff92SAndroid Build Coastguard Worker                     "Possible inputs are: [{}] {}",
5291*89c4ff92SAndroid Build Coastguard Worker                     subgraphId,
5292*89c4ff92SAndroid Build Coastguard Worker                     name,
5293*89c4ff92SAndroid Build Coastguard Worker                     bindings.str(),
5294*89c4ff92SAndroid Build Coastguard Worker                     CHECK_LOCATION().AsString()));
5295*89c4ff92SAndroid Build Coastguard Worker }
5296*89c4ff92SAndroid Build Coastguard Worker 
GetNetworkOutputBindingInfo(size_t subgraphId,const std::string & name) const5297*89c4ff92SAndroid Build Coastguard Worker BindingPointInfo TfLiteParserImpl::GetNetworkOutputBindingInfo(size_t subgraphId,
5298*89c4ff92SAndroid Build Coastguard Worker                                                                const std::string& name) const
5299*89c4ff92SAndroid Build Coastguard Worker {
5300*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphId);
5301*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetSubgraphOutputs(m_Model, subgraphId);
5302*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < outputs.size(); ++i)
5303*89c4ff92SAndroid Build Coastguard Worker     {
5304*89c4ff92SAndroid Build Coastguard Worker         auto const output = outputs[i];
5305*89c4ff92SAndroid Build Coastguard Worker         if (output.second->name == name)
5306*89c4ff92SAndroid Build Coastguard Worker         {
5307*89c4ff92SAndroid Build Coastguard Worker             auto bindingId = GenerateLayerBindingId(subgraphId, output.first);
5308*89c4ff92SAndroid Build Coastguard Worker             std::vector<unsigned int> shape = m_OverriddenOutputShapes.size() > 0 ?
5309*89c4ff92SAndroid Build Coastguard Worker                                                 m_OverriddenOutputShapes[i] : AsUnsignedVector(output.second->shape);
5310*89c4ff92SAndroid Build Coastguard Worker             return std::make_pair(bindingId, ToTensorInfo(output.second, shape));
5311*89c4ff92SAndroid Build Coastguard Worker         }
5312*89c4ff92SAndroid Build Coastguard Worker     }
5313*89c4ff92SAndroid Build Coastguard Worker 
5314*89c4ff92SAndroid Build Coastguard Worker     std::stringstream bindings;
5315*89c4ff92SAndroid Build Coastguard Worker     for (auto const& output : outputs)
5316*89c4ff92SAndroid Build Coastguard Worker     {
5317*89c4ff92SAndroid Build Coastguard Worker         bindings << "'" << output.second->name << "' ";
5318*89c4ff92SAndroid Build Coastguard Worker     }
5319*89c4ff92SAndroid Build Coastguard Worker 
5320*89c4ff92SAndroid Build Coastguard Worker     throw ParseException(
5321*89c4ff92SAndroid Build Coastguard Worker         fmt::format("No output binding found for subgraph:{} and name:{}. "
5322*89c4ff92SAndroid Build Coastguard Worker                     "Possible outputs are: [{}] {}",
5323*89c4ff92SAndroid Build Coastguard Worker                     subgraphId,
5324*89c4ff92SAndroid Build Coastguard Worker                     name,
5325*89c4ff92SAndroid Build Coastguard Worker                     bindings.str(),
5326*89c4ff92SAndroid Build Coastguard Worker                     CHECK_LOCATION().AsString()));
5327*89c4ff92SAndroid Build Coastguard Worker }
5328*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphCount() const5329*89c4ff92SAndroid Build Coastguard Worker size_t TfLiteParserImpl::GetSubgraphCount() const
5330*89c4ff92SAndroid Build Coastguard Worker {
5331*89c4ff92SAndroid Build Coastguard Worker     return m_Model->subgraphs.size();
5332*89c4ff92SAndroid Build Coastguard Worker }
5333*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphInputTensorNames(size_t subgraphId) const5334*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> TfLiteParserImpl::GetSubgraphInputTensorNames(size_t subgraphId) const
5335*89c4ff92SAndroid Build Coastguard Worker {
5336*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphId);
5337*89c4ff92SAndroid Build Coastguard Worker     auto inputs = GetSubgraphInputs(m_Model, subgraphId);
5338*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> result;
5339*89c4ff92SAndroid Build Coastguard Worker     result.reserve(inputs.size());
5340*89c4ff92SAndroid Build Coastguard Worker     for (auto const& input : inputs)
5341*89c4ff92SAndroid Build Coastguard Worker     {
5342*89c4ff92SAndroid Build Coastguard Worker         result.push_back(input.second->name);
5343*89c4ff92SAndroid Build Coastguard Worker     }
5344*89c4ff92SAndroid Build Coastguard Worker     return result;
5345*89c4ff92SAndroid Build Coastguard Worker }
5346*89c4ff92SAndroid Build Coastguard Worker 
GetSubgraphOutputTensorNames(size_t subgraphId) const5347*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> TfLiteParserImpl::GetSubgraphOutputTensorNames(size_t subgraphId) const
5348*89c4ff92SAndroid Build Coastguard Worker {
5349*89c4ff92SAndroid Build Coastguard Worker     CHECK_SUBGRAPH(m_Model, subgraphId);
5350*89c4ff92SAndroid Build Coastguard Worker     auto outputs = GetSubgraphOutputs(m_Model, subgraphId);
5351*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> result;
5352*89c4ff92SAndroid Build Coastguard Worker     result.reserve(outputs.size());
5353*89c4ff92SAndroid Build Coastguard Worker     for (auto const& output : outputs)
5354*89c4ff92SAndroid Build Coastguard Worker     {
5355*89c4ff92SAndroid Build Coastguard Worker         result.push_back(output.second->name);
5356*89c4ff92SAndroid Build Coastguard Worker     }
5357*89c4ff92SAndroid Build Coastguard Worker     return result;
5358*89c4ff92SAndroid Build Coastguard Worker }
5359*89c4ff92SAndroid Build Coastguard Worker 
GetVersion()5360*89c4ff92SAndroid Build Coastguard Worker const std::string TfLiteParserImpl::GetVersion()
5361*89c4ff92SAndroid Build Coastguard Worker {
5362*89c4ff92SAndroid Build Coastguard Worker     return TFLITE_PARSER_VERSION;
5363*89c4ff92SAndroid Build Coastguard Worker }
5364*89c4ff92SAndroid Build Coastguard Worker 
SupportedDataStorage(std::unique_ptr<float[]> && data)5365*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<float[]>&& data)
5366*89c4ff92SAndroid Build Coastguard Worker : m_FloatData(std::move(data))
5367*89c4ff92SAndroid Build Coastguard Worker , m_Uint8Data(nullptr)
5368*89c4ff92SAndroid Build Coastguard Worker , m_Int8Data(nullptr)
5369*89c4ff92SAndroid Build Coastguard Worker , m_Int32Data(nullptr)
5370*89c4ff92SAndroid Build Coastguard Worker {
5371*89c4ff92SAndroid Build Coastguard Worker }
5372*89c4ff92SAndroid Build Coastguard Worker 
SupportedDataStorage(std::unique_ptr<uint8_t[]> && data)5373*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<uint8_t[]>&& data)
5374*89c4ff92SAndroid Build Coastguard Worker : m_FloatData(nullptr)
5375*89c4ff92SAndroid Build Coastguard Worker , m_Uint8Data(std::move(data))
5376*89c4ff92SAndroid Build Coastguard Worker , m_Int8Data(nullptr)
5377*89c4ff92SAndroid Build Coastguard Worker , m_Int32Data(nullptr)
5378*89c4ff92SAndroid Build Coastguard Worker {
5379*89c4ff92SAndroid Build Coastguard Worker }
5380*89c4ff92SAndroid Build Coastguard Worker 
SupportedDataStorage(std::unique_ptr<int8_t[]> && data)5381*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int8_t[]>&& data)
5382*89c4ff92SAndroid Build Coastguard Worker : m_FloatData(nullptr)
5383*89c4ff92SAndroid Build Coastguard Worker , m_Uint8Data(nullptr)
5384*89c4ff92SAndroid Build Coastguard Worker , m_Int8Data(std::move(data))
5385*89c4ff92SAndroid Build Coastguard Worker , m_Int32Data(nullptr)
5386*89c4ff92SAndroid Build Coastguard Worker {
5387*89c4ff92SAndroid Build Coastguard Worker }
5388*89c4ff92SAndroid Build Coastguard Worker 
SupportedDataStorage(std::unique_ptr<int32_t[]> && data)5389*89c4ff92SAndroid Build Coastguard Worker TfLiteParserImpl::SupportedDataStorage::SupportedDataStorage(std::unique_ptr<int32_t[]>&& data)
5390*89c4ff92SAndroid Build Coastguard Worker : m_FloatData(nullptr)
5391*89c4ff92SAndroid Build Coastguard Worker , m_Uint8Data(nullptr)
5392*89c4ff92SAndroid Build Coastguard Worker , m_Int8Data(nullptr)
5393*89c4ff92SAndroid Build Coastguard Worker , m_Int32Data(std::move(data))
5394*89c4ff92SAndroid Build Coastguard Worker {
5395*89c4ff92SAndroid Build Coastguard Worker }
5396*89c4ff92SAndroid Build Coastguard Worker 
5397*89c4ff92SAndroid Build Coastguard Worker } // armnnTfLiteParser
5398