1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "CanonicalUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Transpose.hpp>
18*89c4ff92SAndroid Build Coastguard Worker
19*89c4ff92SAndroid Build Coastguard Worker #include <ActivationFunctor.h>
20*89c4ff92SAndroid Build Coastguard Worker #include <CpuExecutor.h>
21*89c4ff92SAndroid Build Coastguard Worker #include <OperationsUtils.h>
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp>
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker #include <log/log.h>
26*89c4ff92SAndroid Build Coastguard Worker #include <vector>
27*89c4ff92SAndroid Build Coastguard Worker
// Returns model.main, the main subgraph of the given NNAPI model (by reference, no copy).
inline const android::nn::Model::Subgraph& getMainModel(const android::nn::Model& model)
{
    const android::nn::Model::Subgraph& mainSubgraph = model.main;
    return mainSubgraph;
}
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker ///
34*89c4ff92SAndroid Build Coastguard Worker /// Helper classes
35*89c4ff92SAndroid Build Coastguard Worker ///
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/OperandTypes.h>
38*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Result.h>
39*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/TypeUtils.h>
40*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Types.h>
41*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Validation.h>
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker using Model = ::android::nn::Model;
44*89c4ff92SAndroid Build Coastguard Worker using Operand = ::android::nn::Operand;
45*89c4ff92SAndroid Build Coastguard Worker using OperandLifeTime = ::android::nn::Operand::LifeTime;
46*89c4ff92SAndroid Build Coastguard Worker using OperandType = ::android::nn::OperandType;
47*89c4ff92SAndroid Build Coastguard Worker using Operation = ::android::nn::Operation;
48*89c4ff92SAndroid Build Coastguard Worker using OperationType = ::android::nn::OperationType;
49*89c4ff92SAndroid Build Coastguard Worker using ErrorStatus = ::android::nn::ErrorStatus;
50*89c4ff92SAndroid Build Coastguard Worker
51*89c4ff92SAndroid Build Coastguard Worker struct ConversionData
52*89c4ff92SAndroid Build Coastguard Worker {
ConversionDataarmnn_driver::ConversionData53*89c4ff92SAndroid Build Coastguard Worker ConversionData(const std::vector<armnn::BackendId>& backends)
54*89c4ff92SAndroid Build Coastguard Worker : m_Backends(backends)
55*89c4ff92SAndroid Build Coastguard Worker , m_Network(nullptr, nullptr)
56*89c4ff92SAndroid Build Coastguard Worker , m_DynamicInputsEncountered(false)
57*89c4ff92SAndroid Build Coastguard Worker {}
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::BackendId> m_Backends;
60*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr m_Network;
61*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::IOutputSlot*> m_OutputSlotForOperand;
62*89c4ff92SAndroid Build Coastguard Worker std::vector<::android::nn::RunTimePoolInfo> m_MemPools;
63*89c4ff92SAndroid Build Coastguard Worker bool m_DynamicInputsEncountered;
64*89c4ff92SAndroid Build Coastguard Worker };
65*89c4ff92SAndroid Build Coastguard Worker
66*89c4ff92SAndroid Build Coastguard Worker class LayerInputHandle
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker public:
69*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle();
70*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo);
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker bool IsValid() const;
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker void Connect(armnn::IInputSlot& inputSlot);
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker void Disconnect(armnn::IInputSlot& inputSlot);
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& GetTensorInfo() const;
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker void SanitizeQuantizationScale(LayerInputHandle& weight, LayerInputHandle& input);
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* GetOutputSlot() const;
83*89c4ff92SAndroid Build Coastguard Worker
84*89c4ff92SAndroid Build Coastguard Worker private:
85*89c4ff92SAndroid Build Coastguard Worker armnn::IOutputSlot* m_OutputSlot;
86*89c4ff92SAndroid Build Coastguard Worker bool m_Valid;
87*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo m_TensorInfo;
88*89c4ff92SAndroid Build Coastguard Worker };
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker class ConstTensorPin
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker public:
93*89c4ff92SAndroid Build Coastguard Worker // Creates an invalid tensor pin (can be used to signal errors)
94*89c4ff92SAndroid Build Coastguard Worker // The optional flag can be set to indicate the tensor values were missing, but it was otherwise valid
95*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin(bool optional = false);
96*89c4ff92SAndroid Build Coastguard Worker
97*89c4ff92SAndroid Build Coastguard Worker // @param tensorInfo TensorInfo associated with the tensor.
98*89c4ff92SAndroid Build Coastguard Worker // @param valueStart Start address of tensor data. Belongs to one of the memory pools associated with
99*89c4ff92SAndroid Build Coastguard Worker // the model being converted.
100*89c4ff92SAndroid Build Coastguard Worker // @param numBytes Number of bytes for the tensor data.
101*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes,
102*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& mappings);
103*89c4ff92SAndroid Build Coastguard Worker
104*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin(const ConstTensorPin& other) = delete;
105*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin(ConstTensorPin&& other) = default;
106*89c4ff92SAndroid Build Coastguard Worker
107*89c4ff92SAndroid Build Coastguard Worker bool IsValid() const;
108*89c4ff92SAndroid Build Coastguard Worker bool IsOptional() const;
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor& GetConstTensor() const;
111*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor* GetConstTensorPtr() const;
112*89c4ff92SAndroid Build Coastguard Worker
113*89c4ff92SAndroid Build Coastguard Worker private:
114*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor m_ConstTensor;
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker // Owned memory for swizzled tensor data, only required if the tensor needed
117*89c4ff92SAndroid Build Coastguard Worker // swizzling. Otherwise, @ref m_ConstTensor will reference memory from one of
118*89c4ff92SAndroid Build Coastguard Worker // the pools associated with the model being converted.
119*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_SwizzledTensorData;
120*89c4ff92SAndroid Build Coastguard Worker
121*89c4ff92SAndroid Build Coastguard Worker // optional flag to indicate that an invalid tensor pin is not an error, but the optional values were not given
122*89c4ff92SAndroid Build Coastguard Worker bool m_Optional;
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker
125*89c4ff92SAndroid Build Coastguard Worker enum class ConversionResult
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker Success,
128*89c4ff92SAndroid Build Coastguard Worker ErrorMappingPools,
129*89c4ff92SAndroid Build Coastguard Worker UnsupportedFeature
130*89c4ff92SAndroid Build Coastguard Worker };
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
133*89c4ff92SAndroid Build Coastguard Worker
134*89c4ff92SAndroid Build Coastguard Worker ///
135*89c4ff92SAndroid Build Coastguard Worker /// Utility functions
136*89c4ff92SAndroid Build Coastguard Worker ///
137*89c4ff92SAndroid Build Coastguard Worker
138*89c4ff92SAndroid Build Coastguard Worker namespace
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker using namespace armnn_driver;
141*89c4ff92SAndroid Build Coastguard Worker
// Logs (via ALOGD) why a model could not be converted, then reports failure.
// @param formatStr printf-style format string passed to ALOGD.
// @param args      Values substituted into formatStr.
// @return false, always, so callers can signal an error and return in one statement:
//         `return Fail("%s: bad operand", __func__);`
template<class... Args>
static bool Fail(const char* formatStr, Args&&... args)
{
    ALOGD(formatStr, std::forward<Args>(args)...);
    constexpr bool conversionSucceeded = false;
    return conversionSucceeded;
}
150*89c4ff92SAndroid Build Coastguard Worker
// Convenience macro that asks each backend in `backends`, in order, whether it supports a layer (via the
// ILayerSupport member `func`), logging the caller name together with the reason for any lack of support.
//
// Called as:
//   FORWARD_LAYER_SUPPORT_FUNC(__func__, Is*Supported, backends, isSupported, setBackend, a, b, c, d, e)
//
// funcName   - caller name used as a prefix in log messages (normally __func__).
// func       - ILayerSupport member function to call, e.g. IsReshapeSupported.
// backends   - iterable of armnn::BackendId, tried in order; unregistered backends are logged and skipped.
// supported  - bool lvalue; reset to false, then set to true by the first backend that reports support.
// setBackend - armnn::BackendId lvalue; assigned the first supporting backend, left untouched otherwise.
// ...        - arguments passed to func ahead of the trailing reason-if-unsupported argument.
//
// An armnn::InvalidArgumentException thrown while checking support is rethrown as a new
// armnn::InvalidArgumentException built from it with the message "Failed to check layer support".
#define FORWARD_LAYER_SUPPORT_FUNC(funcName, func, backends, supported, setBackend, ...) \
try \
{ \
    supported = false; \
    for (auto&& backendId : backends) \
    { \
        auto layerSupportObject = armnn::GetILayerSupportByBackendId(backendId); \
        if (layerSupportObject.IsBackendRegistered()) \
        { \
            std::string reasonIfUnsupported; \
            supported = \
                layerSupportObject.func(__VA_ARGS__, armnn::Optional<std::string&>(reasonIfUnsupported)); \
            if (supported) \
            { \
                setBackend = backendId; \
                break; \
            } \
            else \
            { \
                if (!reasonIfUnsupported.empty()) \
                { \
                    VLOG(DRIVER) << funcName << ": not supported by armnn: " << reasonIfUnsupported.c_str(); \
                } \
                else \
                { \
                    VLOG(DRIVER) << funcName << ": not supported by armnn"; \
                } \
            } \
        } \
        else \
        { \
            VLOG(DRIVER) << funcName << ": backend not registered: " << backendId.Get().c_str(); \
        } \
    } \
    if (!supported) \
    { \
        VLOG(DRIVER) << funcName << ": not supported by any specified backend"; \
    } \
} \
catch (const armnn::InvalidArgumentException &e) \
{ \
    throw armnn::InvalidArgumentException(e, "Failed to check layer support", CHECK_LOCATION()); \
}
195*89c4ff92SAndroid Build Coastguard Worker
GetTensorShapeForOperand(const Operand & operand)196*89c4ff92SAndroid Build Coastguard Worker inline armnn::TensorShape GetTensorShapeForOperand(const Operand& operand)
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker return armnn::TensorShape(operand.dimensions.size(), operand.dimensions.data());
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker
201*89c4ff92SAndroid Build Coastguard Worker // Support within the 1.3 driver for specific tensor data types
IsOperandTypeSupportedForTensors(OperandType type)202*89c4ff92SAndroid Build Coastguard Worker inline bool IsOperandTypeSupportedForTensors(OperandType type)
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker return type == OperandType::BOOL ||
205*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_BOOL8 ||
206*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_FLOAT16 ||
207*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_FLOAT32 ||
208*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_QUANT8_ASYMM ||
209*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED ||
210*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_QUANT8_SYMM ||
211*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
212*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_QUANT16_SYMM ||
213*89c4ff92SAndroid Build Coastguard Worker type == OperandType::TENSOR_INT32;
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker
IsBool(Operand operand)216*89c4ff92SAndroid Build Coastguard Worker inline bool IsBool(Operand operand)
217*89c4ff92SAndroid Build Coastguard Worker {
218*89c4ff92SAndroid Build Coastguard Worker return operand.type == OperandType::BOOL;
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker
Is12OrLaterOperand(Operand)221*89c4ff92SAndroid Build Coastguard Worker inline bool Is12OrLaterOperand(Operand)
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker return true;
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker
226*89c4ff92SAndroid Build Coastguard Worker
227*89c4ff92SAndroid Build Coastguard Worker template<typename LayerHandleType>
AddReshapeLayer(armnn::INetwork & network,LayerHandleType & inputLayer,armnn::TensorInfo reshapeInfo)228*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& AddReshapeLayer(armnn::INetwork& network,
229*89c4ff92SAndroid Build Coastguard Worker LayerHandleType& inputLayer,
230*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo reshapeInfo)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker armnn::ReshapeDescriptor reshapeDescriptor;
233*89c4ff92SAndroid Build Coastguard Worker reshapeDescriptor.m_TargetShape = reshapeInfo.GetShape();
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* reshapeLayer = network.AddReshapeLayer(reshapeDescriptor);
236*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(reshapeLayer != nullptr);
237*89c4ff92SAndroid Build Coastguard Worker
238*89c4ff92SAndroid Build Coastguard Worker // Attach the input layer to the reshape layer
239*89c4ff92SAndroid Build Coastguard Worker inputLayer.Connect(reshapeLayer->GetInputSlot(0));
240*89c4ff92SAndroid Build Coastguard Worker reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapeInfo);
241*89c4ff92SAndroid Build Coastguard Worker
242*89c4ff92SAndroid Build Coastguard Worker return *reshapeLayer;
243*89c4ff92SAndroid Build Coastguard Worker }
244*89c4ff92SAndroid Build Coastguard Worker
245*89c4ff92SAndroid Build Coastguard Worker
// Flattens an input of rank > 2 into the 2D shape { batchSize, inputSize } used by a fully connected
// layer, where inputSize is weightsShape[1] and batchSize = total input elements / inputSize.
// Inputs of rank <= 2 are returned unchanged.
// @throws std::runtime_error("Failed to deduce tensor shape") if the input's element count is not an
//         exact multiple of inputSize (including inputSize == 0), since no flattened shape would then
//         hold exactly the input's elements.
armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape,
                                              const armnn::TensorShape& weightsShape)
{
    if (inputShape.GetNumDimensions() <= 2U)
    {
        return inputShape;
    }

    const unsigned int totalInputElements = inputShape.GetNumElements();
    const unsigned int inputSize = weightsShape[1];

    // The element count must divide evenly by inputSize (the divisor of the batch computation), not by
    // the derived batch size. Checking inputSize == 0 first also avoids a division by zero.
    if (inputSize == 0 || totalInputElements % inputSize != 0)
    {
        throw std::runtime_error("Failed to deduce tensor shape");
    }

    const unsigned int batchSize = totalInputElements / inputSize;
    return armnn::TensorShape({batchSize, inputSize});
}
268*89c4ff92SAndroid Build Coastguard Worker
VerifyFullyConnectedShapes(const armnn::TensorShape & inputShape,const armnn::TensorShape & weightsShape,const armnn::TensorShape & outputShape,bool transposeWeightMatrix)269*89c4ff92SAndroid Build Coastguard Worker inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape,
270*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& weightsShape,
271*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& outputShape,
272*89c4ff92SAndroid Build Coastguard Worker bool transposeWeightMatrix)
273*89c4ff92SAndroid Build Coastguard Worker {
274*89c4ff92SAndroid Build Coastguard Worker unsigned int dimIdx = transposeWeightMatrix ? 0 : 1;
275*89c4ff92SAndroid Build Coastguard Worker return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]);
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker
BroadcastTensor(LayerInputHandle & input0,LayerInputHandle & input1,armnn::IConnectableLayer * startLayer,ConversionData & data)278*89c4ff92SAndroid Build Coastguard Worker bool BroadcastTensor(LayerInputHandle& input0,
279*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle& input1,
280*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* startLayer,
281*89c4ff92SAndroid Build Coastguard Worker ConversionData& data)
282*89c4ff92SAndroid Build Coastguard Worker {
283*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(startLayer != nullptr);
284*89c4ff92SAndroid Build Coastguard Worker
285*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo0 = input0.GetTensorInfo();
286*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo1 = input1.GetTensorInfo();
287*89c4ff92SAndroid Build Coastguard Worker
288*89c4ff92SAndroid Build Coastguard Worker unsigned int inputDimensions0 = inputInfo0.GetNumDimensions();
289*89c4ff92SAndroid Build Coastguard Worker unsigned int inputDimensions1 = inputInfo1.GetNumDimensions();
290*89c4ff92SAndroid Build Coastguard Worker
291*89c4ff92SAndroid Build Coastguard Worker if (inputDimensions0 == inputDimensions1)
292*89c4ff92SAndroid Build Coastguard Worker {
293*89c4ff92SAndroid Build Coastguard Worker // The inputs have the same number of dimensions, simply connect them to the given layer as they are
294*89c4ff92SAndroid Build Coastguard Worker input0.Connect(startLayer->GetInputSlot(0));
295*89c4ff92SAndroid Build Coastguard Worker input1.Connect(startLayer->GetInputSlot(1));
296*89c4ff92SAndroid Build Coastguard Worker
297*89c4ff92SAndroid Build Coastguard Worker return true;
298*89c4ff92SAndroid Build Coastguard Worker }
299*89c4ff92SAndroid Build Coastguard Worker
300*89c4ff92SAndroid Build Coastguard Worker // Since the number of dimensions do not match then we need to add degenerate dimensions
301*89c4ff92SAndroid Build Coastguard Worker // to the "smaller" tensor using a reshape, while keeping the order of the inputs.
302*89c4ff92SAndroid Build Coastguard Worker
303*89c4ff92SAndroid Build Coastguard Worker unsigned int maxInputDimensions = std::max(inputDimensions0, inputDimensions1);
304*89c4ff92SAndroid Build Coastguard Worker unsigned int sizeDifference = std::abs(armnn::numeric_cast<int>(inputDimensions0) -
305*89c4ff92SAndroid Build Coastguard Worker armnn::numeric_cast<int>(inputDimensions1));
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker bool input0IsSmaller = inputDimensions0 < inputDimensions1;
308*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle& smallInputHandle = input0IsSmaller ? input0 : input1;
309*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& smallInfo = smallInputHandle.GetTensorInfo();
310*89c4ff92SAndroid Build Coastguard Worker
311*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape& smallShape = smallInfo.GetShape();
312*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> reshapedDimensions(maxInputDimensions, 1);
313*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = sizeDifference; i < maxInputDimensions; i++)
314*89c4ff92SAndroid Build Coastguard Worker {
315*89c4ff92SAndroid Build Coastguard Worker reshapedDimensions[i] = smallShape[i - sizeDifference];
316*89c4ff92SAndroid Build Coastguard Worker }
317*89c4ff92SAndroid Build Coastguard Worker
318*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo reshapedInfo = smallInfo;
319*89c4ff92SAndroid Build Coastguard Worker reshapedInfo.SetShape(armnn::TensorShape{ armnn::numeric_cast<unsigned int>(reshapedDimensions.size()),
320*89c4ff92SAndroid Build Coastguard Worker reshapedDimensions.data() });
321*89c4ff92SAndroid Build Coastguard Worker
322*89c4ff92SAndroid Build Coastguard Worker // RehsapeDescriptor that is ignored in the IsReshapeSupported function
323*89c4ff92SAndroid Build Coastguard Worker armnn::ReshapeDescriptor reshapeDescriptor;
324*89c4ff92SAndroid Build Coastguard Worker
325*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
326*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
327*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
328*89c4ff92SAndroid Build Coastguard Worker IsReshapeSupported,
329*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
330*89c4ff92SAndroid Build Coastguard Worker isSupported,
331*89c4ff92SAndroid Build Coastguard Worker setBackend,
332*89c4ff92SAndroid Build Coastguard Worker smallInfo,
333*89c4ff92SAndroid Build Coastguard Worker reshapedInfo,
334*89c4ff92SAndroid Build Coastguard Worker reshapeDescriptor);
335*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
336*89c4ff92SAndroid Build Coastguard Worker {
337*89c4ff92SAndroid Build Coastguard Worker return false;
338*89c4ff92SAndroid Build Coastguard Worker }
339*89c4ff92SAndroid Build Coastguard Worker
340*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(data.m_Network != nullptr);
341*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& reshapeLayer = AddReshapeLayer(*data.m_Network, smallInputHandle, reshapedInfo);
342*89c4ff92SAndroid Build Coastguard Worker reshapeLayer.SetBackendId(setBackend);
343*89c4ff92SAndroid Build Coastguard Worker
344*89c4ff92SAndroid Build Coastguard Worker if (input0IsSmaller)
345*89c4ff92SAndroid Build Coastguard Worker {
346*89c4ff92SAndroid Build Coastguard Worker // Input0 is the "smaller" tensor, connect the reshape layer as follows:
347*89c4ff92SAndroid Build Coastguard Worker //
348*89c4ff92SAndroid Build Coastguard Worker // Input0 Input1
349*89c4ff92SAndroid Build Coastguard Worker // | |
350*89c4ff92SAndroid Build Coastguard Worker // Reshape |
351*89c4ff92SAndroid Build Coastguard Worker // \ /
352*89c4ff92SAndroid Build Coastguard Worker // StartLayer
353*89c4ff92SAndroid Build Coastguard Worker
354*89c4ff92SAndroid Build Coastguard Worker reshapeLayer.GetOutputSlot(0).Connect(startLayer->GetInputSlot(0));
355*89c4ff92SAndroid Build Coastguard Worker input1.Connect(startLayer->GetInputSlot(1));
356*89c4ff92SAndroid Build Coastguard Worker }
357*89c4ff92SAndroid Build Coastguard Worker else
358*89c4ff92SAndroid Build Coastguard Worker {
359*89c4ff92SAndroid Build Coastguard Worker // Input1 is the "smaller" tensor, connect the reshape layer as follows:
360*89c4ff92SAndroid Build Coastguard Worker //
361*89c4ff92SAndroid Build Coastguard Worker // Input0 Input1
362*89c4ff92SAndroid Build Coastguard Worker // | |
363*89c4ff92SAndroid Build Coastguard Worker // | Reshape
364*89c4ff92SAndroid Build Coastguard Worker // \ /
365*89c4ff92SAndroid Build Coastguard Worker // StartLayer
366*89c4ff92SAndroid Build Coastguard Worker
367*89c4ff92SAndroid Build Coastguard Worker input0.Connect(startLayer->GetInputSlot(0));
368*89c4ff92SAndroid Build Coastguard Worker reshapeLayer.GetOutputSlot(0).Connect(startLayer->GetInputSlot(1));
369*89c4ff92SAndroid Build Coastguard Worker }
370*89c4ff92SAndroid Build Coastguard Worker
371*89c4ff92SAndroid Build Coastguard Worker return true;
372*89c4ff92SAndroid Build Coastguard Worker }
373*89c4ff92SAndroid Build Coastguard Worker
CalcPadding(uint32_t input,uint32_t kernel,uint32_t stride,uint32_t & outPadHead,uint32_t & outPadTail,PaddingScheme scheme)374*89c4ff92SAndroid Build Coastguard Worker void CalcPadding(uint32_t input,
375*89c4ff92SAndroid Build Coastguard Worker uint32_t kernel,
376*89c4ff92SAndroid Build Coastguard Worker uint32_t stride,
377*89c4ff92SAndroid Build Coastguard Worker uint32_t& outPadHead,
378*89c4ff92SAndroid Build Coastguard Worker uint32_t& outPadTail,
379*89c4ff92SAndroid Build Coastguard Worker PaddingScheme scheme)
380*89c4ff92SAndroid Build Coastguard Worker {
381*89c4ff92SAndroid Build Coastguard Worker int32_t padHead;
382*89c4ff92SAndroid Build Coastguard Worker int32_t padTail;
383*89c4ff92SAndroid Build Coastguard Worker calculateExplicitPadding(input, stride, kernel, scheme, &padHead, &padTail);
384*89c4ff92SAndroid Build Coastguard Worker outPadHead = armnn::numeric_cast<uint32_t>(padHead);
385*89c4ff92SAndroid Build Coastguard Worker outPadTail = armnn::numeric_cast<uint32_t>(padTail);
386*89c4ff92SAndroid Build Coastguard Worker }
387*89c4ff92SAndroid Build Coastguard Worker
CalcPadding(uint32_t input,uint32_t kernel,uint32_t stride,uint32_t dilation,uint32_t & outPadHead,uint32_t & outPadTail,::android::nn::PaddingScheme scheme)388*89c4ff92SAndroid Build Coastguard Worker void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t dilation, uint32_t& outPadHead,
389*89c4ff92SAndroid Build Coastguard Worker uint32_t& outPadTail, ::android::nn::PaddingScheme scheme)
390*89c4ff92SAndroid Build Coastguard Worker {
391*89c4ff92SAndroid Build Coastguard Worker int32_t padHead;
392*89c4ff92SAndroid Build Coastguard Worker int32_t padTail;
393*89c4ff92SAndroid Build Coastguard Worker calculateExplicitPadding(input, stride, dilation, kernel, scheme, &padHead, &padTail);
394*89c4ff92SAndroid Build Coastguard Worker outPadHead = armnn::numeric_cast<uint32_t>(padHead);
395*89c4ff92SAndroid Build Coastguard Worker outPadTail = armnn::numeric_cast<uint32_t>(padTail);
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker
CalcPaddingTransposeConv(uint32_t output,uint32_t kernel,int32_t stride,int32_t & outPadHead,int32_t & outPadTail,::android::nn::PaddingScheme scheme)398*89c4ff92SAndroid Build Coastguard Worker inline void CalcPaddingTransposeConv(uint32_t output, uint32_t kernel, int32_t stride, int32_t& outPadHead,
399*89c4ff92SAndroid Build Coastguard Worker int32_t& outPadTail, ::android::nn::PaddingScheme scheme)
400*89c4ff92SAndroid Build Coastguard Worker {
401*89c4ff92SAndroid Build Coastguard Worker calculateExplicitPaddingTransposeConv(output, stride, kernel, scheme, &outPadHead, &outPadTail);
402*89c4ff92SAndroid Build Coastguard Worker }
403*89c4ff92SAndroid Build Coastguard Worker
GetOperandShape(const Operand & operand)404*89c4ff92SAndroid Build Coastguard Worker Shape GetOperandShape(const Operand& operand)
405*89c4ff92SAndroid Build Coastguard Worker {
406*89c4ff92SAndroid Build Coastguard Worker Shape shape;
407*89c4ff92SAndroid Build Coastguard Worker shape.type = OperandType(operand.type);
408*89c4ff92SAndroid Build Coastguard Worker shape.dimensions = operand.dimensions;
409*89c4ff92SAndroid Build Coastguard Worker shape.scale = operand.scale;
410*89c4ff92SAndroid Build Coastguard Worker shape.offset = operand.zeroPoint;
411*89c4ff92SAndroid Build Coastguard Worker return shape;
412*89c4ff92SAndroid Build Coastguard Worker }
413*89c4ff92SAndroid Build Coastguard Worker
414*89c4ff92SAndroid Build Coastguard Worker
415*89c4ff92SAndroid Build Coastguard Worker // ArmNN requires the bias scale to be equal to the product of the weight and input scales, which is also
416*89c4ff92SAndroid Build Coastguard Worker // what AndroidNN requires. However for some of the AndroidNN tests the values don't exactly match so
417*89c4ff92SAndroid Build Coastguard Worker // we accept some tolerance. We don't want ArmNN itself to accept these inconsistencies as it is up to the
418*89c4ff92SAndroid Build Coastguard Worker // user (us, in this case) to ensure they match.
SanitizeBiasQuantizationScale(armnn::TensorInfo & biasInfo,const armnn::TensorInfo & weightInfo,const armnn::TensorInfo & inputInfo)419*89c4ff92SAndroid Build Coastguard Worker void SanitizeBiasQuantizationScale(armnn::TensorInfo& biasInfo,
420*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& weightInfo,
421*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& inputInfo)
422*89c4ff92SAndroid Build Coastguard Worker {
423*89c4ff92SAndroid Build Coastguard Worker if (weightInfo.HasPerAxisQuantization())
424*89c4ff92SAndroid Build Coastguard Worker {
425*89c4ff92SAndroid Build Coastguard Worker // NOTE: Bias scale is always set to 0 for per-axis quantization and
426*89c4ff92SAndroid Build Coastguard Worker // it needs to be calculated: scale[i] = input_scale * weight_scale[i]
427*89c4ff92SAndroid Build Coastguard Worker auto UpdateBiasScaleValue = [&inputInfo](float biasScale) -> float
428*89c4ff92SAndroid Build Coastguard Worker {
429*89c4ff92SAndroid Build Coastguard Worker return biasScale * inputInfo.GetQuantizationScale();
430*89c4ff92SAndroid Build Coastguard Worker };
431*89c4ff92SAndroid Build Coastguard Worker
432*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasScales(weightInfo.GetQuantizationScales());
433*89c4ff92SAndroid Build Coastguard Worker std::transform(biasScales.begin(), biasScales.end(), biasScales.begin(), UpdateBiasScaleValue);
434*89c4ff92SAndroid Build Coastguard Worker
435*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetQuantizationScales(biasScales);
436*89c4ff92SAndroid Build Coastguard Worker // bias is expected to be a 1d tensor, set qdim=0
437*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetQuantizationDim(0);
438*89c4ff92SAndroid Build Coastguard Worker
439*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Bias quantization params have been updated for per-axis quantization";
440*89c4ff92SAndroid Build Coastguard Worker }
441*89c4ff92SAndroid Build Coastguard Worker else
442*89c4ff92SAndroid Build Coastguard Worker {
443*89c4ff92SAndroid Build Coastguard Worker const float expectedBiasScale = weightInfo.GetQuantizationScale() * inputInfo.GetQuantizationScale();
444*89c4ff92SAndroid Build Coastguard Worker if (biasInfo.GetQuantizationScale() != expectedBiasScale)
445*89c4ff92SAndroid Build Coastguard Worker {
446*89c4ff92SAndroid Build Coastguard Worker if (armnnUtils::within_percentage_tolerance(biasInfo.GetQuantizationScale(), expectedBiasScale, 1.0f))
447*89c4ff92SAndroid Build Coastguard Worker {
448*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Bias quantization scale has been modified to match input * weights";
449*89c4ff92SAndroid Build Coastguard Worker biasInfo.SetQuantizationScale(expectedBiasScale);
450*89c4ff92SAndroid Build Coastguard Worker }
451*89c4ff92SAndroid Build Coastguard Worker }
452*89c4ff92SAndroid Build Coastguard Worker }
453*89c4ff92SAndroid Build Coastguard Worker }
454*89c4ff92SAndroid Build Coastguard Worker
455*89c4ff92SAndroid Build Coastguard Worker // 4D Tensor Permutations
456*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector IdentityPermutation4D({ 0U, 1U, 2U, 3U });
457*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector IdentityPermutation3D({ 0U, 1U, 2U });
458*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector SwapDim2And3({ 0U, 1U, 3U, 2U });
459*89c4ff92SAndroid Build Coastguard Worker
460*89c4ff92SAndroid Build Coastguard Worker // 3D Permutation Vectors
461*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector RotateTensorLeft({ 1U, 2U, 0U });
462*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector RotateTensorRight({ 2U, 0U, 1U });
463*89c4ff92SAndroid Build Coastguard Worker
464*89c4ff92SAndroid Build Coastguard Worker template<typename OSlot>
AddTransposeLayer(armnn::INetwork & network,OSlot & input,const armnn::PermutationVector & mappings)465*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& AddTransposeLayer(armnn::INetwork& network, OSlot& input,
466*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& mappings)
467*89c4ff92SAndroid Build Coastguard Worker {
468*89c4ff92SAndroid Build Coastguard Worker // Add swizzle layer
469*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* const layer = network.AddTransposeLayer(mappings);
470*89c4ff92SAndroid Build Coastguard Worker
471*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer != nullptr);
472*89c4ff92SAndroid Build Coastguard Worker
473*89c4ff92SAndroid Build Coastguard Worker // Connect input to swizzle layer
474*89c4ff92SAndroid Build Coastguard Worker input.Connect(layer->GetInputSlot(0));
475*89c4ff92SAndroid Build Coastguard Worker
476*89c4ff92SAndroid Build Coastguard Worker // Setup swizzled output
477*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outInfo = armnnUtils::TransposeTensorShape(input.GetTensorInfo(), mappings);
478*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(outInfo);
479*89c4ff92SAndroid Build Coastguard Worker
480*89c4ff92SAndroid Build Coastguard Worker return *layer;
481*89c4ff92SAndroid Build Coastguard Worker }
482*89c4ff92SAndroid Build Coastguard Worker
ValidateConcatOutputShape(const std::vector<armnn::TensorShape> & inputShapes,const armnn::TensorShape & outputShape,uint32_t concatDim)483*89c4ff92SAndroid Build Coastguard Worker bool ValidateConcatOutputShape(const std::vector<armnn::TensorShape> & inputShapes,
484*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape & outputShape,
485*89c4ff92SAndroid Build Coastguard Worker uint32_t concatDim)
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker // Validate the output shape is correct given the input shapes (which have just been validated)
488*89c4ff92SAndroid Build Coastguard Worker unsigned int numDimensions = inputShapes[0].GetNumDimensions();
489*89c4ff92SAndroid Build Coastguard Worker if (outputShape.GetNumDimensions() != numDimensions)
490*89c4ff92SAndroid Build Coastguard Worker {
491*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Output shape has wrong number of dimensions", __func__);
492*89c4ff92SAndroid Build Coastguard Worker }
493*89c4ff92SAndroid Build Coastguard Worker
494*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSizeAlongConcatenatedDimension = 0;
495*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < inputShapes.size(); i++)
496*89c4ff92SAndroid Build Coastguard Worker {
497*89c4ff92SAndroid Build Coastguard Worker outputSizeAlongConcatenatedDimension += inputShapes[i][concatDim];
498*89c4ff92SAndroid Build Coastguard Worker }
499*89c4ff92SAndroid Build Coastguard Worker
500*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numDimensions; ++i)
501*89c4ff92SAndroid Build Coastguard Worker {
502*89c4ff92SAndroid Build Coastguard Worker if (i == concatDim)
503*89c4ff92SAndroid Build Coastguard Worker {
504*89c4ff92SAndroid Build Coastguard Worker if (outputShape[i] != outputSizeAlongConcatenatedDimension)
505*89c4ff92SAndroid Build Coastguard Worker {
506*89c4ff92SAndroid Build Coastguard Worker return Fail(
507*89c4ff92SAndroid Build Coastguard Worker "%s: Invalid output shape for dimension %d (%d != %d)",
508*89c4ff92SAndroid Build Coastguard Worker __func__,
509*89c4ff92SAndroid Build Coastguard Worker i,
510*89c4ff92SAndroid Build Coastguard Worker outputShape[i],
511*89c4ff92SAndroid Build Coastguard Worker outputSizeAlongConcatenatedDimension);
512*89c4ff92SAndroid Build Coastguard Worker }
513*89c4ff92SAndroid Build Coastguard Worker }
514*89c4ff92SAndroid Build Coastguard Worker else
515*89c4ff92SAndroid Build Coastguard Worker {
516*89c4ff92SAndroid Build Coastguard Worker if (outputShape[i] != inputShapes[0][i])
517*89c4ff92SAndroid Build Coastguard Worker {
518*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Invalid output shape", __func__);
519*89c4ff92SAndroid Build Coastguard Worker }
520*89c4ff92SAndroid Build Coastguard Worker }
521*89c4ff92SAndroid Build Coastguard Worker }
522*89c4ff92SAndroid Build Coastguard Worker
523*89c4ff92SAndroid Build Coastguard Worker return true;
524*89c4ff92SAndroid Build Coastguard Worker }
525*89c4ff92SAndroid Build Coastguard Worker
RequiresReshape(armnn::TensorShape & inputShape)526*89c4ff92SAndroid Build Coastguard Worker inline bool RequiresReshape(armnn::TensorShape & inputShape)
527*89c4ff92SAndroid Build Coastguard Worker {
528*89c4ff92SAndroid Build Coastguard Worker return inputShape.GetNumDimensions() < 3;
529*89c4ff92SAndroid Build Coastguard Worker }
530*89c4ff92SAndroid Build Coastguard Worker
SwizzleInputs(armnn::INetwork & network,std::vector<LayerInputHandle> & inputs,std::vector<armnn::TensorShape> & inputShapes,const armnn::PermutationVector & mapping,std::vector<armnn::BackendId> & setBackends)531*89c4ff92SAndroid Build Coastguard Worker inline void SwizzleInputs(armnn::INetwork& network,
532*89c4ff92SAndroid Build Coastguard Worker std::vector<LayerInputHandle>& inputs,
533*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape>& inputShapes,
534*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& mapping,
535*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId>& setBackends)
536*89c4ff92SAndroid Build Coastguard Worker {
537*89c4ff92SAndroid Build Coastguard Worker if (!mapping.IsEqual(IdentityPermutation4D))
538*89c4ff92SAndroid Build Coastguard Worker {
539*89c4ff92SAndroid Build Coastguard Worker size_t nInputs = inputs.size();
540*89c4ff92SAndroid Build Coastguard Worker for (size_t i=0; i<nInputs; ++i)
541*89c4ff92SAndroid Build Coastguard Worker {
542*89c4ff92SAndroid Build Coastguard Worker // add swizzle layer
543*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& swizzleLayer = AddTransposeLayer(network, inputs[i], mapping);
544*89c4ff92SAndroid Build Coastguard Worker swizzleLayer.SetBackendId(setBackends[i]);
545*89c4ff92SAndroid Build Coastguard Worker auto& outputSlot = swizzleLayer.GetOutputSlot(0);
546*89c4ff92SAndroid Build Coastguard Worker auto& outputInfo = outputSlot.GetTensorInfo();
547*89c4ff92SAndroid Build Coastguard Worker // replace inputs with the swizzled ones
548*89c4ff92SAndroid Build Coastguard Worker inputs[i] = LayerInputHandle(true, &outputSlot, outputInfo);
549*89c4ff92SAndroid Build Coastguard Worker inputShapes[i] = inputs[i].GetTensorInfo().GetShape();
550*89c4ff92SAndroid Build Coastguard Worker }
551*89c4ff92SAndroid Build Coastguard Worker }
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker
TransposeInputTensors(ConversionData & data,std::vector<LayerInputHandle> & inputs,std::vector<armnn::TensorShape> & inputShapes,const armnn::PermutationVector & mapping)554*89c4ff92SAndroid Build Coastguard Worker bool TransposeInputTensors(ConversionData& data,
555*89c4ff92SAndroid Build Coastguard Worker std::vector<LayerInputHandle>& inputs,
556*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::TensorShape>& inputShapes,
557*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& mapping)
558*89c4ff92SAndroid Build Coastguard Worker {
559*89c4ff92SAndroid Build Coastguard Worker // If we have a IdentityPermutation4D or IdentityPermutation3D then we are not permuting
560*89c4ff92SAndroid Build Coastguard Worker if (!mapping.IsEqual(IdentityPermutation4D) && !mapping.IsEqual(IdentityPermutation3D))
561*89c4ff92SAndroid Build Coastguard Worker {
562*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::BackendId> setBackendsVec;
563*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTransposeInfo;
564*89c4ff92SAndroid Build Coastguard Worker size_t nInputs = inputs.size();
565*89c4ff92SAndroid Build Coastguard Worker for (size_t i=0; i<nInputs; ++i)
566*89c4ff92SAndroid Build Coastguard Worker {
567*89c4ff92SAndroid Build Coastguard Worker // check permute layer
568*89c4ff92SAndroid Build Coastguard Worker armnn::TransposeDescriptor transposeDesc;
569*89c4ff92SAndroid Build Coastguard Worker transposeDesc.m_DimMappings = mapping;
570*89c4ff92SAndroid Build Coastguard Worker outputTransposeInfo = armnnUtils::TransposeTensorShape(inputs[i].GetTensorInfo(), mapping);
571*89c4ff92SAndroid Build Coastguard Worker
572*89c4ff92SAndroid Build Coastguard Worker bool isSupported = false;
573*89c4ff92SAndroid Build Coastguard Worker armnn::BackendId setBackend;
574*89c4ff92SAndroid Build Coastguard Worker FORWARD_LAYER_SUPPORT_FUNC(__func__,
575*89c4ff92SAndroid Build Coastguard Worker IsTransposeSupported,
576*89c4ff92SAndroid Build Coastguard Worker data.m_Backends,
577*89c4ff92SAndroid Build Coastguard Worker isSupported,
578*89c4ff92SAndroid Build Coastguard Worker setBackend,
579*89c4ff92SAndroid Build Coastguard Worker inputs[i].GetTensorInfo(),
580*89c4ff92SAndroid Build Coastguard Worker outputTransposeInfo,
581*89c4ff92SAndroid Build Coastguard Worker transposeDesc);
582*89c4ff92SAndroid Build Coastguard Worker setBackendsVec.push_back(setBackend);
583*89c4ff92SAndroid Build Coastguard Worker if (!isSupported)
584*89c4ff92SAndroid Build Coastguard Worker {
585*89c4ff92SAndroid Build Coastguard Worker return false;
586*89c4ff92SAndroid Build Coastguard Worker }
587*89c4ff92SAndroid Build Coastguard Worker
588*89c4ff92SAndroid Build Coastguard Worker }
589*89c4ff92SAndroid Build Coastguard Worker SwizzleInputs(*data.m_Network, inputs, inputShapes, mapping, setBackendsVec);
590*89c4ff92SAndroid Build Coastguard Worker }
591*89c4ff92SAndroid Build Coastguard Worker return true;
592*89c4ff92SAndroid Build Coastguard Worker }
593*89c4ff92SAndroid Build Coastguard Worker
CreateConcatPermutationParameters(const unsigned int numberOfDimensions,int32_t & concatDimension,std::pair<armnn::PermutationVector,armnn::PermutationVector> & permutationPair)594*89c4ff92SAndroid Build Coastguard Worker bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions,
595*89c4ff92SAndroid Build Coastguard Worker int32_t & concatDimension,
596*89c4ff92SAndroid Build Coastguard Worker std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
597*89c4ff92SAndroid Build Coastguard Worker {
598*89c4ff92SAndroid Build Coastguard Worker bool needPermute = false;
599*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(numberOfDimensions >= 3);
600*89c4ff92SAndroid Build Coastguard Worker
601*89c4ff92SAndroid Build Coastguard Worker // ArmNN uses Compute Library subtensors to perform concatenation
602*89c4ff92SAndroid Build Coastguard Worker // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor,
603*89c4ff92SAndroid Build Coastguard Worker // or along dimension 0 or 2 for a 3-D tensor.
604*89c4ff92SAndroid Build Coastguard Worker if (numberOfDimensions == 4 && concatDimension == 2)
605*89c4ff92SAndroid Build Coastguard Worker {
606*89c4ff92SAndroid Build Coastguard Worker concatDimension = 3;
607*89c4ff92SAndroid Build Coastguard Worker permutationPair = std::make_pair(SwapDim2And3, SwapDim2And3);
608*89c4ff92SAndroid Build Coastguard Worker needPermute = true;
609*89c4ff92SAndroid Build Coastguard Worker }
610*89c4ff92SAndroid Build Coastguard Worker else if (numberOfDimensions == 3 && concatDimension == 1)
611*89c4ff92SAndroid Build Coastguard Worker {
612*89c4ff92SAndroid Build Coastguard Worker concatDimension = 0;
613*89c4ff92SAndroid Build Coastguard Worker permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
614*89c4ff92SAndroid Build Coastguard Worker needPermute = true;
615*89c4ff92SAndroid Build Coastguard Worker }
616*89c4ff92SAndroid Build Coastguard Worker // If the tensor is 3-D and the concat dimension is 2 then we don't need to permute but we do need to change the
617*89c4ff92SAndroid Build Coastguard Worker // permutation identity to only have 3 dimensions
618*89c4ff92SAndroid Build Coastguard Worker else if (numberOfDimensions == 3 && concatDimension == 2)
619*89c4ff92SAndroid Build Coastguard Worker {
620*89c4ff92SAndroid Build Coastguard Worker permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D);
621*89c4ff92SAndroid Build Coastguard Worker }
622*89c4ff92SAndroid Build Coastguard Worker return needPermute;
623*89c4ff92SAndroid Build Coastguard Worker }
624*89c4ff92SAndroid Build Coastguard Worker
625*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
626*89c4ff92SAndroid Build Coastguard Worker
627*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
628*89c4ff92SAndroid Build Coastguard Worker {
629*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
630*89c4ff92SAndroid Build Coastguard Worker
631*89c4ff92SAndroid Build Coastguard Worker //// Creates an ArmNN activation layer and connects it to the given layer, if the
632*89c4ff92SAndroid Build Coastguard Worker //// passed in AndroidNN activation function requires so.
633*89c4ff92SAndroid Build Coastguard Worker //// @return The end layer of the sequence of layers built for the given AndroidNN
634*89c4ff92SAndroid Build Coastguard Worker //// activation function or nullptr if an error occurred (e.g. unsupported activation).
635*89c4ff92SAndroid Build Coastguard Worker //// Note that the end layer matches the input layer if no activation is required
636*89c4ff92SAndroid Build Coastguard Worker //// (the sequence of layers has length 1).
637*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
638*89c4ff92SAndroid Build Coastguard Worker ActivationFn activation,
639*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* prevLayer,
640*89c4ff92SAndroid Build Coastguard Worker ConversionData& data);
641*89c4ff92SAndroid Build Coastguard Worker
642*89c4ff92SAndroid Build Coastguard Worker
GetInputOperand(const Operation & operation,uint32_t inputIndex,const Model & model,bool failOnIndexOutOfBounds=true)643*89c4ff92SAndroid Build Coastguard Worker inline const Operand* GetInputOperand(const Operation& operation,
644*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
645*89c4ff92SAndroid Build Coastguard Worker const Model& model,
646*89c4ff92SAndroid Build Coastguard Worker bool failOnIndexOutOfBounds = true)
647*89c4ff92SAndroid Build Coastguard Worker {
648*89c4ff92SAndroid Build Coastguard Worker if (inputIndex >= operation.inputs.size())
649*89c4ff92SAndroid Build Coastguard Worker {
650*89c4ff92SAndroid Build Coastguard Worker if (failOnIndexOutOfBounds)
651*89c4ff92SAndroid Build Coastguard Worker {
652*89c4ff92SAndroid Build Coastguard Worker Fail("%s: invalid input index: %i out of %i", __func__, inputIndex, operation.inputs.size());
653*89c4ff92SAndroid Build Coastguard Worker }
654*89c4ff92SAndroid Build Coastguard Worker return nullptr;
655*89c4ff92SAndroid Build Coastguard Worker }
656*89c4ff92SAndroid Build Coastguard Worker
657*89c4ff92SAndroid Build Coastguard Worker // Model should have been validated beforehand
658*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(operation.inputs[inputIndex] < getMainModel(model).operands.size());
659*89c4ff92SAndroid Build Coastguard Worker return &getMainModel(model).operands[operation.inputs[inputIndex]];
660*89c4ff92SAndroid Build Coastguard Worker }
661*89c4ff92SAndroid Build Coastguard Worker
GetOutputOperand(const Operation & operation,uint32_t outputIndex,const Model & model)662*89c4ff92SAndroid Build Coastguard Worker inline const Operand* GetOutputOperand(const Operation& operation,
663*89c4ff92SAndroid Build Coastguard Worker uint32_t outputIndex,
664*89c4ff92SAndroid Build Coastguard Worker const Model& model)
665*89c4ff92SAndroid Build Coastguard Worker {
666*89c4ff92SAndroid Build Coastguard Worker if (outputIndex >= operation.outputs.size())
667*89c4ff92SAndroid Build Coastguard Worker {
668*89c4ff92SAndroid Build Coastguard Worker Fail("%s: invalid output index: %i out of %i", __func__, outputIndex, operation.outputs.size());
669*89c4ff92SAndroid Build Coastguard Worker return nullptr;
670*89c4ff92SAndroid Build Coastguard Worker }
671*89c4ff92SAndroid Build Coastguard Worker
672*89c4ff92SAndroid Build Coastguard Worker // Model should have been validated beforehand
673*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(operation.outputs[outputIndex] < getMainModel(model).operands.size());
674*89c4ff92SAndroid Build Coastguard Worker
675*89c4ff92SAndroid Build Coastguard Worker return &getMainModel(model).operands[operation.outputs[outputIndex]];
676*89c4ff92SAndroid Build Coastguard Worker }
677*89c4ff92SAndroid Build Coastguard Worker
678*89c4ff92SAndroid Build Coastguard Worker const void* GetOperandValueReadOnlyAddress(const Operand& operand,
679*89c4ff92SAndroid Build Coastguard Worker const Model& model,
680*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
681*89c4ff92SAndroid Build Coastguard Worker bool optional = false);
682*89c4ff92SAndroid Build Coastguard Worker
GetOperandType(const Operation & operation,uint32_t inputIndex,const Model & model,OperandType & type)683*89c4ff92SAndroid Build Coastguard Worker inline bool GetOperandType(const Operation& operation,
684*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
685*89c4ff92SAndroid Build Coastguard Worker const Model& model,
686*89c4ff92SAndroid Build Coastguard Worker OperandType& type)
687*89c4ff92SAndroid Build Coastguard Worker {
688*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
689*89c4ff92SAndroid Build Coastguard Worker if (!operand)
690*89c4ff92SAndroid Build Coastguard Worker {
691*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
692*89c4ff92SAndroid Build Coastguard Worker }
693*89c4ff92SAndroid Build Coastguard Worker
694*89c4ff92SAndroid Build Coastguard Worker type = operand->type;
695*89c4ff92SAndroid Build Coastguard Worker return true;
696*89c4ff92SAndroid Build Coastguard Worker }
697*89c4ff92SAndroid Build Coastguard Worker
IsOperandConstant(const Operand & operand)698*89c4ff92SAndroid Build Coastguard Worker inline bool IsOperandConstant(const Operand& operand)
699*89c4ff92SAndroid Build Coastguard Worker {
700*89c4ff92SAndroid Build Coastguard Worker OperandLifeTime lifetime = operand.lifetime;
701*89c4ff92SAndroid Build Coastguard Worker
702*89c4ff92SAndroid Build Coastguard Worker return lifetime == OperandLifeTime::CONSTANT_COPY ||
703*89c4ff92SAndroid Build Coastguard Worker lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
704*89c4ff92SAndroid Build Coastguard Worker lifetime == OperandLifeTime::POINTER ||
705*89c4ff92SAndroid Build Coastguard Worker lifetime == OperandLifeTime::NO_VALUE;
706*89c4ff92SAndroid Build Coastguard Worker }
707*89c4ff92SAndroid Build Coastguard Worker
708*89c4ff92SAndroid Build Coastguard Worker bool IsWeightsValid(const Operation& operation, uint32_t inputIndex, const Model& model);
709*89c4ff92SAndroid Build Coastguard Worker
710*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
711*89c4ff92SAndroid Build Coastguard Worker const Model& model,
712*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
713*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& dimensionMappings = g_DontPermute,
714*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape* overrideTensorShape = nullptr,
715*89c4ff92SAndroid Build Coastguard Worker bool optional = false,
716*89c4ff92SAndroid Build Coastguard Worker const armnn::DataType* overrideDataType = nullptr);
717*89c4ff92SAndroid Build Coastguard Worker
ConvertOperationInputToConstTensorPin(const Operation & operation,uint32_t inputIndex,const Model & model,const ConversionData & data,const armnn::PermutationVector & dimensionMappings=g_DontPermute,const armnn::TensorShape * overrideTensorShape=nullptr,bool optional=false)718*89c4ff92SAndroid Build Coastguard Worker inline ConstTensorPin ConvertOperationInputToConstTensorPin(
719*89c4ff92SAndroid Build Coastguard Worker const Operation& operation,
720*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
721*89c4ff92SAndroid Build Coastguard Worker const Model& model,
722*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
723*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& dimensionMappings = g_DontPermute,
724*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorShape* overrideTensorShape = nullptr,
725*89c4ff92SAndroid Build Coastguard Worker bool optional = false)
726*89c4ff92SAndroid Build Coastguard Worker {
727*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
728*89c4ff92SAndroid Build Coastguard Worker if (!operand)
729*89c4ff92SAndroid Build Coastguard Worker {
730*89c4ff92SAndroid Build Coastguard Worker Fail("%s: failed to get input operand: index=%u", __func__, inputIndex);
731*89c4ff92SAndroid Build Coastguard Worker return ConstTensorPin();
732*89c4ff92SAndroid Build Coastguard Worker }
733*89c4ff92SAndroid Build Coastguard Worker return ConvertOperandToConstTensorPin(*operand,
734*89c4ff92SAndroid Build Coastguard Worker model,
735*89c4ff92SAndroid Build Coastguard Worker data,
736*89c4ff92SAndroid Build Coastguard Worker dimensionMappings,
737*89c4ff92SAndroid Build Coastguard Worker overrideTensorShape,
738*89c4ff92SAndroid Build Coastguard Worker optional);
739*89c4ff92SAndroid Build Coastguard Worker }
740*89c4ff92SAndroid Build Coastguard Worker
741*89c4ff92SAndroid Build Coastguard Worker template <typename OutputType>
GetInputScalar(const Operation & operation,uint32_t inputIndex,OperandType type,OutputType & outValue,const Model & model,const ConversionData & data,bool optional=false)742*89c4ff92SAndroid Build Coastguard Worker bool GetInputScalar(const Operation& operation,
743*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
744*89c4ff92SAndroid Build Coastguard Worker OperandType type,
745*89c4ff92SAndroid Build Coastguard Worker OutputType& outValue,
746*89c4ff92SAndroid Build Coastguard Worker const Model& model,
747*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
748*89c4ff92SAndroid Build Coastguard Worker bool optional = false)
749*89c4ff92SAndroid Build Coastguard Worker {
750*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
751*89c4ff92SAndroid Build Coastguard Worker if (!optional && !operand)
752*89c4ff92SAndroid Build Coastguard Worker {
753*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
754*89c4ff92SAndroid Build Coastguard Worker }
755*89c4ff92SAndroid Build Coastguard Worker
756*89c4ff92SAndroid Build Coastguard Worker if (!optional && operand->type != type)
757*89c4ff92SAndroid Build Coastguard Worker {
758*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": unexpected operand type: " << operand->type << " should be: " << type;
759*89c4ff92SAndroid Build Coastguard Worker return false;
760*89c4ff92SAndroid Build Coastguard Worker }
761*89c4ff92SAndroid Build Coastguard Worker
762*89c4ff92SAndroid Build Coastguard Worker if (!optional && operand->location.length != sizeof(OutputType))
763*89c4ff92SAndroid Build Coastguard Worker {
764*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: incorrect operand location length: %i (should be %i)",
765*89c4ff92SAndroid Build Coastguard Worker __func__, operand->location.length, sizeof(OutputType));
766*89c4ff92SAndroid Build Coastguard Worker }
767*89c4ff92SAndroid Build Coastguard Worker
768*89c4ff92SAndroid Build Coastguard Worker const void* valueAddress = GetOperandValueReadOnlyAddress(*operand, model, data);
769*89c4ff92SAndroid Build Coastguard Worker if (!optional && !valueAddress)
770*89c4ff92SAndroid Build Coastguard Worker {
771*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: failed to get address for operand", __func__);
772*89c4ff92SAndroid Build Coastguard Worker }
773*89c4ff92SAndroid Build Coastguard Worker
774*89c4ff92SAndroid Build Coastguard Worker if(!optional)
775*89c4ff92SAndroid Build Coastguard Worker {
776*89c4ff92SAndroid Build Coastguard Worker outValue = *(static_cast<const OutputType*>(valueAddress));
777*89c4ff92SAndroid Build Coastguard Worker }
778*89c4ff92SAndroid Build Coastguard Worker
779*89c4ff92SAndroid Build Coastguard Worker return true;
780*89c4ff92SAndroid Build Coastguard Worker }
781*89c4ff92SAndroid Build Coastguard Worker
GetInputInt32(const Operation & operation,uint32_t inputIndex,int32_t & outValue,const Model & model,const ConversionData & data)782*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputInt32(const Operation& operation,
783*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
784*89c4ff92SAndroid Build Coastguard Worker int32_t& outValue,
785*89c4ff92SAndroid Build Coastguard Worker const Model& model,
786*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
787*89c4ff92SAndroid Build Coastguard Worker {
788*89c4ff92SAndroid Build Coastguard Worker return GetInputScalar(operation, inputIndex, OperandType::INT32, outValue, model, data);
789*89c4ff92SAndroid Build Coastguard Worker }
790*89c4ff92SAndroid Build Coastguard Worker
GetInputFloat32(const Operation & operation,uint32_t inputIndex,float & outValue,const Model & model,const ConversionData & data)791*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputFloat32(const Operation& operation,
792*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
793*89c4ff92SAndroid Build Coastguard Worker float& outValue,
794*89c4ff92SAndroid Build Coastguard Worker const Model& model,
795*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
796*89c4ff92SAndroid Build Coastguard Worker {
797*89c4ff92SAndroid Build Coastguard Worker return GetInputScalar(operation, inputIndex, OperandType::FLOAT32, outValue, model, data);
798*89c4ff92SAndroid Build Coastguard Worker }
799*89c4ff92SAndroid Build Coastguard Worker
GetInputActivationFunctionImpl(const Operation & operation,uint32_t inputIndex,OperandType type,ActivationFn & outActivationFunction,const Model & model,const ConversionData & data)800*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputActivationFunctionImpl(const Operation& operation,
801*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
802*89c4ff92SAndroid Build Coastguard Worker OperandType type,
803*89c4ff92SAndroid Build Coastguard Worker ActivationFn& outActivationFunction,
804*89c4ff92SAndroid Build Coastguard Worker const Model& model,
805*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
806*89c4ff92SAndroid Build Coastguard Worker {
807*89c4ff92SAndroid Build Coastguard Worker if (type != OperandType::INT32 && type != OperandType::TENSOR_INT32)
808*89c4ff92SAndroid Build Coastguard Worker {
809*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << __func__ << ": unexpected operand type: " << type
810*89c4ff92SAndroid Build Coastguard Worker << " should be OperandType::INT32 or OperandType::TENSOR_INT32";
811*89c4ff92SAndroid Build Coastguard Worker return false;
812*89c4ff92SAndroid Build Coastguard Worker }
813*89c4ff92SAndroid Build Coastguard Worker
814*89c4ff92SAndroid Build Coastguard Worker int32_t activationFunctionAsInt;
815*89c4ff92SAndroid Build Coastguard Worker if (!GetInputScalar(operation, inputIndex, type, activationFunctionAsInt, model, data))
816*89c4ff92SAndroid Build Coastguard Worker {
817*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: failed to get activation input value", __func__);
818*89c4ff92SAndroid Build Coastguard Worker }
819*89c4ff92SAndroid Build Coastguard Worker outActivationFunction = static_cast<ActivationFn>(activationFunctionAsInt);
820*89c4ff92SAndroid Build Coastguard Worker return true;
821*89c4ff92SAndroid Build Coastguard Worker }
822*89c4ff92SAndroid Build Coastguard Worker
GetInputActivationFunction(const Operation & operation,uint32_t inputIndex,ActivationFn & outActivationFunction,const Model & model,const ConversionData & data)823*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputActivationFunction(const Operation& operation,
824*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
825*89c4ff92SAndroid Build Coastguard Worker ActivationFn& outActivationFunction,
826*89c4ff92SAndroid Build Coastguard Worker const Model& model,
827*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
828*89c4ff92SAndroid Build Coastguard Worker {
829*89c4ff92SAndroid Build Coastguard Worker return GetInputActivationFunctionImpl(operation,
830*89c4ff92SAndroid Build Coastguard Worker inputIndex,
831*89c4ff92SAndroid Build Coastguard Worker OperandType::INT32,
832*89c4ff92SAndroid Build Coastguard Worker outActivationFunction,
833*89c4ff92SAndroid Build Coastguard Worker model,
834*89c4ff92SAndroid Build Coastguard Worker data);
835*89c4ff92SAndroid Build Coastguard Worker }
836*89c4ff92SAndroid Build Coastguard Worker
GetInputActivationFunctionFromTensor(const Operation & operation,uint32_t inputIndex,ActivationFn & outActivationFunction,const Model & model,const ConversionData & data)837*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputActivationFunctionFromTensor(const Operation& operation,
838*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
839*89c4ff92SAndroid Build Coastguard Worker ActivationFn& outActivationFunction,
840*89c4ff92SAndroid Build Coastguard Worker const Model& model,
841*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
842*89c4ff92SAndroid Build Coastguard Worker {
843*89c4ff92SAndroid Build Coastguard Worker // This only accepts a 1-D tensor of size 1
844*89c4ff92SAndroid Build Coastguard Worker return GetInputActivationFunctionImpl(operation,
845*89c4ff92SAndroid Build Coastguard Worker inputIndex,
846*89c4ff92SAndroid Build Coastguard Worker OperandType::INT32,
847*89c4ff92SAndroid Build Coastguard Worker outActivationFunction,
848*89c4ff92SAndroid Build Coastguard Worker model,
849*89c4ff92SAndroid Build Coastguard Worker data);
850*89c4ff92SAndroid Build Coastguard Worker }
851*89c4ff92SAndroid Build Coastguard Worker
852*89c4ff92SAndroid Build Coastguard Worker
GetOptionalInputActivation(const Operation & operation,uint32_t inputIndex,ActivationFn & activationFunction,const Model & model,const ConversionData & data)853*89c4ff92SAndroid Build Coastguard Worker inline bool GetOptionalInputActivation(const Operation& operation,
854*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
855*89c4ff92SAndroid Build Coastguard Worker ActivationFn& activationFunction,
856*89c4ff92SAndroid Build Coastguard Worker const Model& model,
857*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
858*89c4ff92SAndroid Build Coastguard Worker {
859*89c4ff92SAndroid Build Coastguard Worker if (operation.inputs.size() <= inputIndex)
860*89c4ff92SAndroid Build Coastguard Worker {
861*89c4ff92SAndroid Build Coastguard Worker activationFunction = ActivationFn::kActivationNone;
862*89c4ff92SAndroid Build Coastguard Worker }
863*89c4ff92SAndroid Build Coastguard Worker else
864*89c4ff92SAndroid Build Coastguard Worker {
865*89c4ff92SAndroid Build Coastguard Worker if (!GetInputActivationFunction(operation, inputIndex, activationFunction, model, data))
866*89c4ff92SAndroid Build Coastguard Worker {
867*89c4ff92SAndroid Build Coastguard Worker return Fail("%s: Operation has invalid inputs", __func__);
868*89c4ff92SAndroid Build Coastguard Worker }
869*89c4ff92SAndroid Build Coastguard Worker }
870*89c4ff92SAndroid Build Coastguard Worker return true;
871*89c4ff92SAndroid Build Coastguard Worker }
872*89c4ff92SAndroid Build Coastguard Worker
873*89c4ff92SAndroid Build Coastguard Worker template<typename ConvolutionDescriptor>
GetOptionalConvolutionDilationParams(const Operation & operation,uint32_t dilationXIndex,ConvolutionDescriptor & descriptor,const Model & model,const ConversionData & data)874*89c4ff92SAndroid Build Coastguard Worker bool GetOptionalConvolutionDilationParams(const Operation& operation,
875*89c4ff92SAndroid Build Coastguard Worker uint32_t dilationXIndex,
876*89c4ff92SAndroid Build Coastguard Worker ConvolutionDescriptor& descriptor,
877*89c4ff92SAndroid Build Coastguard Worker const Model& model,
878*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
879*89c4ff92SAndroid Build Coastguard Worker {
880*89c4ff92SAndroid Build Coastguard Worker bool success = true;
881*89c4ff92SAndroid Build Coastguard Worker if (operation.inputs.size() >= dilationXIndex + 2)
882*89c4ff92SAndroid Build Coastguard Worker {
883*89c4ff92SAndroid Build Coastguard Worker success &= GetInputScalar(operation,
884*89c4ff92SAndroid Build Coastguard Worker dilationXIndex,
885*89c4ff92SAndroid Build Coastguard Worker OperandType::INT32,
886*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationX,
887*89c4ff92SAndroid Build Coastguard Worker model,
888*89c4ff92SAndroid Build Coastguard Worker data);
889*89c4ff92SAndroid Build Coastguard Worker success &= GetInputScalar(operation,
890*89c4ff92SAndroid Build Coastguard Worker dilationXIndex + 1,
891*89c4ff92SAndroid Build Coastguard Worker OperandType::INT32,
892*89c4ff92SAndroid Build Coastguard Worker descriptor.m_DilationY,
893*89c4ff92SAndroid Build Coastguard Worker model,
894*89c4ff92SAndroid Build Coastguard Worker data);
895*89c4ff92SAndroid Build Coastguard Worker }
896*89c4ff92SAndroid Build Coastguard Worker
897*89c4ff92SAndroid Build Coastguard Worker return success;
898*89c4ff92SAndroid Build Coastguard Worker }
899*89c4ff92SAndroid Build Coastguard Worker
GetOptionalBool(const Operation & operation,uint32_t inputIndex,const Model & model,const ConversionData & data)900*89c4ff92SAndroid Build Coastguard Worker inline bool GetOptionalBool(const Operation& operation,
901*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
902*89c4ff92SAndroid Build Coastguard Worker const Model& model,
903*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data)
904*89c4ff92SAndroid Build Coastguard Worker {
905*89c4ff92SAndroid Build Coastguard Worker const Operand* operand = GetInputOperand(operation, inputIndex, model);
906*89c4ff92SAndroid Build Coastguard Worker if (!operand)
907*89c4ff92SAndroid Build Coastguard Worker {
908*89c4ff92SAndroid Build Coastguard Worker return false;
909*89c4ff92SAndroid Build Coastguard Worker }
910*89c4ff92SAndroid Build Coastguard Worker
911*89c4ff92SAndroid Build Coastguard Worker if (!IsBool(*operand))
912*89c4ff92SAndroid Build Coastguard Worker {
913*89c4ff92SAndroid Build Coastguard Worker return false;
914*89c4ff92SAndroid Build Coastguard Worker }
915*89c4ff92SAndroid Build Coastguard Worker
916*89c4ff92SAndroid Build Coastguard Worker const void* valueAddress = GetOperandValueReadOnlyAddress(*operand, model, data);
917*89c4ff92SAndroid Build Coastguard Worker if (!valueAddress)
918*89c4ff92SAndroid Build Coastguard Worker {
919*89c4ff92SAndroid Build Coastguard Worker return false;
920*89c4ff92SAndroid Build Coastguard Worker }
921*89c4ff92SAndroid Build Coastguard Worker
922*89c4ff92SAndroid Build Coastguard Worker return *(static_cast<const bool*>(valueAddress));
923*89c4ff92SAndroid Build Coastguard Worker }
924*89c4ff92SAndroid Build Coastguard Worker
925*89c4ff92SAndroid Build Coastguard Worker bool GetTensorInt32Values(const Operand& operand,
926*89c4ff92SAndroid Build Coastguard Worker std::vector<int32_t>& outValues,
927*89c4ff92SAndroid Build Coastguard Worker const Model& model,
928*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data);
929*89c4ff92SAndroid Build Coastguard Worker
930*89c4ff92SAndroid Build Coastguard Worker bool GetInputPaddingScheme(const Operation& operation,
931*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
932*89c4ff92SAndroid Build Coastguard Worker PaddingScheme& outPaddingScheme,
933*89c4ff92SAndroid Build Coastguard Worker const Model& model,
934*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data);
935*89c4ff92SAndroid Build Coastguard Worker
936*89c4ff92SAndroid Build Coastguard Worker LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
937*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
938*89c4ff92SAndroid Build Coastguard Worker const Model& model,
939*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
940*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector& dimensionMappings = g_DontPermute,
941*89c4ff92SAndroid Build Coastguard Worker const LayerInputHandle* inputHandle = nullptr);
942*89c4ff92SAndroid Build Coastguard Worker
943*89c4ff92SAndroid Build Coastguard Worker bool SetupAndTrackLayerOutputSlot(const Operation& operation,
944*89c4ff92SAndroid Build Coastguard Worker uint32_t operationOutputIndex,
945*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& layer,
946*89c4ff92SAndroid Build Coastguard Worker uint32_t layerOutputIndex,
947*89c4ff92SAndroid Build Coastguard Worker const Model& model,
948*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
949*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo* overrideOutputInfo = nullptr,
950*89c4ff92SAndroid Build Coastguard Worker const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr,
951*89c4ff92SAndroid Build Coastguard Worker const ActivationFn& activationFunction = ActivationFn::kActivationNone,
952*89c4ff92SAndroid Build Coastguard Worker bool inferOutputShapes = false);
953*89c4ff92SAndroid Build Coastguard Worker
954*89c4ff92SAndroid Build Coastguard Worker armnn::DataLayout OptionalDataLayout(const Operation& operation,
955*89c4ff92SAndroid Build Coastguard Worker uint32_t inputIndex,
956*89c4ff92SAndroid Build Coastguard Worker const Model& model,
957*89c4ff92SAndroid Build Coastguard Worker ConversionData& data);
958*89c4ff92SAndroid Build Coastguard Worker
SetupAndTrackLayerOutputSlot(const Operation & operation,uint32_t outputIndex,armnn::IConnectableLayer & layer,const Model & model,ConversionData & data,const armnn::TensorInfo * overrideOutputInfo=nullptr,const std::function<void (const armnn::TensorInfo &,bool &)> & validateFunc=nullptr,const ActivationFn & activationFunction=ActivationFn::kActivationNone)959*89c4ff92SAndroid Build Coastguard Worker inline bool SetupAndTrackLayerOutputSlot(
960*89c4ff92SAndroid Build Coastguard Worker const Operation& operation,
961*89c4ff92SAndroid Build Coastguard Worker uint32_t outputIndex,
962*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& layer,
963*89c4ff92SAndroid Build Coastguard Worker const Model& model,
964*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
965*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo* overrideOutputInfo = nullptr,
966*89c4ff92SAndroid Build Coastguard Worker const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr,
967*89c4ff92SAndroid Build Coastguard Worker const ActivationFn& activationFunction = ActivationFn::kActivationNone)
968*89c4ff92SAndroid Build Coastguard Worker {
969*89c4ff92SAndroid Build Coastguard Worker return SetupAndTrackLayerOutputSlot(operation,
970*89c4ff92SAndroid Build Coastguard Worker outputIndex,
971*89c4ff92SAndroid Build Coastguard Worker layer,
972*89c4ff92SAndroid Build Coastguard Worker outputIndex,
973*89c4ff92SAndroid Build Coastguard Worker model,
974*89c4ff92SAndroid Build Coastguard Worker data,
975*89c4ff92SAndroid Build Coastguard Worker overrideOutputInfo,
976*89c4ff92SAndroid Build Coastguard Worker validateFunc,
977*89c4ff92SAndroid Build Coastguard Worker activationFunction);
978*89c4ff92SAndroid Build Coastguard Worker }
979*89c4ff92SAndroid Build Coastguard Worker
980*89c4ff92SAndroid Build Coastguard Worker bool ConvertToActivation(const Operation& operation,
981*89c4ff92SAndroid Build Coastguard Worker const char* operationName,
982*89c4ff92SAndroid Build Coastguard Worker const armnn::ActivationDescriptor& activationDesc,
983*89c4ff92SAndroid Build Coastguard Worker const Model& model,
984*89c4ff92SAndroid Build Coastguard Worker ConversionData& data);
985*89c4ff92SAndroid Build Coastguard Worker
986*89c4ff92SAndroid Build Coastguard Worker bool ConvertPaddings(const Operation& operation,
987*89c4ff92SAndroid Build Coastguard Worker const Model& model,
988*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
989*89c4ff92SAndroid Build Coastguard Worker unsigned int rank,
990*89c4ff92SAndroid Build Coastguard Worker armnn::PadDescriptor& padDescriptor);
991*89c4ff92SAndroid Build Coastguard Worker bool ConvertReduce(const Operation& operation,
992*89c4ff92SAndroid Build Coastguard Worker const Model& model,
993*89c4ff92SAndroid Build Coastguard Worker ConversionData& data,
994*89c4ff92SAndroid Build Coastguard Worker armnn::ReduceOperation reduceOperation);
995*89c4ff92SAndroid Build Coastguard Worker
996*89c4ff92SAndroid Build Coastguard Worker bool ConvertPooling2d(const Operation& operation,
997*89c4ff92SAndroid Build Coastguard Worker const char* operationName,
998*89c4ff92SAndroid Build Coastguard Worker armnn::PoolingAlgorithm poolType,
999*89c4ff92SAndroid Build Coastguard Worker const Model& model,
1000*89c4ff92SAndroid Build Coastguard Worker ConversionData& data);
1001*89c4ff92SAndroid Build Coastguard Worker
IsQSymm8(const Operand & operand)1002*89c4ff92SAndroid Build Coastguard Worker inline bool IsQSymm8(const Operand& operand)
1003*89c4ff92SAndroid Build Coastguard Worker {
1004*89c4ff92SAndroid Build Coastguard Worker return operand.type == OperandType::TENSOR_QUANT8_SYMM;
1005*89c4ff92SAndroid Build Coastguard Worker }
1006*89c4ff92SAndroid Build Coastguard Worker
1007*89c4ff92SAndroid Build Coastguard Worker enum class DequantizeStatus
1008*89c4ff92SAndroid Build Coastguard Worker {
1009*89c4ff92SAndroid Build Coastguard Worker SUCCESS,
1010*89c4ff92SAndroid Build Coastguard Worker NOT_REQUIRED,
1011*89c4ff92SAndroid Build Coastguard Worker INVALID_OPERAND
1012*89c4ff92SAndroid Build Coastguard Worker };
1013*89c4ff92SAndroid Build Coastguard Worker
1014*89c4ff92SAndroid Build Coastguard Worker using DequantizeResult = std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo, DequantizeStatus>;
1015*89c4ff92SAndroid Build Coastguard Worker
1016*89c4ff92SAndroid Build Coastguard Worker DequantizeResult DequantizeIfRequired(size_t operand_index,
1017*89c4ff92SAndroid Build Coastguard Worker const Operation& operation,
1018*89c4ff92SAndroid Build Coastguard Worker const Model& model,
1019*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data);
1020*89c4ff92SAndroid Build Coastguard Worker
1021*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation,
1022*89c4ff92SAndroid Build Coastguard Worker const Model& model,
1023*89c4ff92SAndroid Build Coastguard Worker const ConversionData& data,
1024*89c4ff92SAndroid Build Coastguard Worker size_t operandIndex,
1025*89c4ff92SAndroid Build Coastguard Worker bool optional = false);
1026*89c4ff92SAndroid Build Coastguard Worker
1027*89c4ff92SAndroid Build Coastguard Worker bool IsConnectedToDequantize(armnn::IOutputSlot* ioutputSlot);
1028*89c4ff92SAndroid Build Coastguard Worker
1029*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
1030