xref: /aosp_15_r20/external/armnn/shim/sl/canonical/ConversionUtils.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "CanonicalUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/ArmNN.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/NumericCast.hpp>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/DataLayoutIndexed.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Transpose.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <ActivationFunctor.h>
20*89c4ff92SAndroid Build Coastguard Worker #include <CpuExecutor.h>
21*89c4ff92SAndroid Build Coastguard Worker #include <OperationsUtils.h>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/FloatingPointComparison.hpp>
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <log/log.h>
26*89c4ff92SAndroid Build Coastguard Worker #include <vector>
27*89c4ff92SAndroid Build Coastguard Worker 
// Returns the `main` subgraph of a canonical NNAPI model.
inline const android::nn::Model::Subgraph& getMainModel(const android::nn::Model& model)
{
    return model.main;
}
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker ///
34*89c4ff92SAndroid Build Coastguard Worker /// Helper classes
35*89c4ff92SAndroid Build Coastguard Worker ///
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/OperandTypes.h>
38*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Result.h>
39*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/TypeUtils.h>
40*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Types.h>
41*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Validation.h>
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker using Model                     = ::android::nn::Model;
44*89c4ff92SAndroid Build Coastguard Worker using Operand                   = ::android::nn::Operand;
45*89c4ff92SAndroid Build Coastguard Worker using OperandLifeTime           = ::android::nn::Operand::LifeTime;
46*89c4ff92SAndroid Build Coastguard Worker using OperandType               = ::android::nn::OperandType;
47*89c4ff92SAndroid Build Coastguard Worker using Operation                 = ::android::nn::Operation;
48*89c4ff92SAndroid Build Coastguard Worker using OperationType             = ::android::nn::OperationType;
49*89c4ff92SAndroid Build Coastguard Worker using ErrorStatus               = ::android::nn::ErrorStatus;
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker struct ConversionData
52*89c4ff92SAndroid Build Coastguard Worker {
ConversionDataarmnn_driver::ConversionData53*89c4ff92SAndroid Build Coastguard Worker     ConversionData(const std::vector<armnn::BackendId>& backends)
54*89c4ff92SAndroid Build Coastguard Worker     : m_Backends(backends)
55*89c4ff92SAndroid Build Coastguard Worker     , m_Network(nullptr, nullptr)
56*89c4ff92SAndroid Build Coastguard Worker     , m_DynamicInputsEncountered(false)
57*89c4ff92SAndroid Build Coastguard Worker     {}
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     const std::vector<armnn::BackendId>       m_Backends;
60*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr                        m_Network;
61*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::IOutputSlot*>          m_OutputSlotForOperand;
62*89c4ff92SAndroid Build Coastguard Worker     std::vector<::android::nn::RunTimePoolInfo> m_MemPools;
63*89c4ff92SAndroid Build Coastguard Worker     bool m_DynamicInputsEncountered;
64*89c4ff92SAndroid Build Coastguard Worker };
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker class LayerInputHandle
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker public:
69*89c4ff92SAndroid Build Coastguard Worker     LayerInputHandle();
70*89c4ff92SAndroid Build Coastguard Worker     LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo);
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker     bool IsValid() const;
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     void Connect(armnn::IInputSlot& inputSlot);
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     void Disconnect(armnn::IInputSlot& inputSlot);
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& GetTensorInfo() const;
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker     void SanitizeQuantizationScale(LayerInputHandle& weight, LayerInputHandle& input);
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot* GetOutputSlot() const;
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker private:
85*89c4ff92SAndroid Build Coastguard Worker     armnn::IOutputSlot* m_OutputSlot;
86*89c4ff92SAndroid Build Coastguard Worker     bool                m_Valid;
87*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo   m_TensorInfo;
88*89c4ff92SAndroid Build Coastguard Worker };
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker class ConstTensorPin
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker public:
93*89c4ff92SAndroid Build Coastguard Worker     // Creates an invalid tensor pin (can be used to signal errors)
94*89c4ff92SAndroid Build Coastguard Worker     // The optional flag can be set to indicate the tensor values were missing, but it was otherwise valid
95*89c4ff92SAndroid Build Coastguard Worker     ConstTensorPin(bool optional = false);
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker     // @param tensorInfo TensorInfo associated with the tensor.
98*89c4ff92SAndroid Build Coastguard Worker     // @param valueStart Start address of tensor data. Belongs to one of the memory pools associated with
99*89c4ff92SAndroid Build Coastguard Worker     // the model being converted.
100*89c4ff92SAndroid Build Coastguard Worker     // @param numBytes Number of bytes for the tensor data.
101*89c4ff92SAndroid Build Coastguard Worker     ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes,
102*89c4ff92SAndroid Build Coastguard Worker                    const armnn::PermutationVector& mappings);
103*89c4ff92SAndroid Build Coastguard Worker 
104*89c4ff92SAndroid Build Coastguard Worker     ConstTensorPin(const ConstTensorPin& other) = delete;
105*89c4ff92SAndroid Build Coastguard Worker     ConstTensorPin(ConstTensorPin&& other)      = default;
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     bool IsValid() const;
108*89c4ff92SAndroid Build Coastguard Worker     bool IsOptional() const;
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     const armnn::ConstTensor& GetConstTensor() const;
111*89c4ff92SAndroid Build Coastguard Worker     const armnn::ConstTensor* GetConstTensorPtr() const;
112*89c4ff92SAndroid Build Coastguard Worker 
113*89c4ff92SAndroid Build Coastguard Worker private:
114*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor m_ConstTensor;
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     // Owned memory for swizzled tensor data, only required if the tensor needed
117*89c4ff92SAndroid Build Coastguard Worker     // swizzling. Otherwise, @ref m_ConstTensor will reference memory from one of
118*89c4ff92SAndroid Build Coastguard Worker     // the pools associated with the model being converted.
119*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> m_SwizzledTensorData;
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     // optional flag to indicate that an invalid tensor pin is not an error, but the optional values were not given
122*89c4ff92SAndroid Build Coastguard Worker     bool m_Optional;
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker 
// Overall outcome of a model conversion attempt (enumerator meanings inferred from names — confirm at use sites).
enum class ConversionResult
{
    Success,            // conversion completed
    ErrorMappingPools,  // presumably: the model's memory pools could not be mapped
    UnsupportedFeature  // presumably: the model uses something the driver cannot convert
};
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker ///
135*89c4ff92SAndroid Build Coastguard Worker /// Utility functions
136*89c4ff92SAndroid Build Coastguard Worker ///
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker namespace
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker using namespace armnn_driver;
141*89c4ff92SAndroid Build Coastguard Worker 
// Logs a printf-style message at debug level (ALOGD) explaining why conversion failed.
// Always returns false, so callers can write `return Fail("...", args...);`.
template <typename... Args>
static bool Fail(const char* formatStr, Args&&... args)
{
    ALOGD(formatStr, std::forward<Args>(args)...);
    return false;
}
150*89c4ff92SAndroid Build Coastguard Worker 
// Asks each backend in `backends`, in order, whether it supports a layer by calling
// ILayerSupport::func(__VA_ARGS__, reasonIfUnsupported). Stops at the first backend that reports
// support and records it in `setBackend`. Every negative outcome is logged with VLOG(DRIVER),
// prefixed by `funcName`.
//
// Called as:
//     FORWARD_LAYER_SUPPORT_FUNC(__func__, Is*Supported, backends, isSupported, setBackend, a, b, c, ...)
//
// `supported` is overwritten by each registered backend's answer. If no backend is registered (or
// `backends` is empty) it keeps the value it had before the macro ran. An
// armnn::InvalidArgumentException thrown during the query is rethrown with added context.
#define FORWARD_LAYER_SUPPORT_FUNC(funcName, func, backends, supported, setBackend, ...) \
try \
{ \
    for (auto&& backendId : backends) \
    { \
        auto layerSupportObject = armnn::GetILayerSupportByBackendId(backendId); \
        if (!layerSupportObject.IsBackendRegistered()) \
        { \
            VLOG(DRIVER) << funcName << ": backend not registered: " << backendId.Get().c_str(); \
            continue; \
        } \
        std::string reasonIfUnsupported; \
        supported = \
            layerSupportObject.func(__VA_ARGS__, armnn::Optional<std::string&>(reasonIfUnsupported)); \
        if (supported) \
        { \
            setBackend = backendId; \
            break; \
        } \
        if (reasonIfUnsupported.empty()) \
        { \
            VLOG(DRIVER) << funcName << ": not supported by armnn"; \
        } \
        else \
        { \
            VLOG(DRIVER) << funcName << ": not supported by armnn: " <<  reasonIfUnsupported.c_str(); \
        } \
    } \
    if (!supported) \
    { \
        VLOG(DRIVER) << funcName << ": not supported by any specified backend"; \
    } \
} \
catch (const armnn::InvalidArgumentException &e) \
{ \
    throw armnn::InvalidArgumentException(e, "Failed to check layer support", CHECK_LOCATION()); \
}
195*89c4ff92SAndroid Build Coastguard Worker 
GetTensorShapeForOperand(const Operand & operand)196*89c4ff92SAndroid Build Coastguard Worker inline armnn::TensorShape GetTensorShapeForOperand(const Operand& operand)
197*89c4ff92SAndroid Build Coastguard Worker {
198*89c4ff92SAndroid Build Coastguard Worker     return armnn::TensorShape(operand.dimensions.size(), operand.dimensions.data());
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker // Support within the 1.3 driver for specific tensor data types
IsOperandTypeSupportedForTensors(OperandType type)202*89c4ff92SAndroid Build Coastguard Worker inline bool IsOperandTypeSupportedForTensors(OperandType type)
203*89c4ff92SAndroid Build Coastguard Worker {
204*89c4ff92SAndroid Build Coastguard Worker     return type == OperandType::BOOL                           ||
205*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_BOOL8                   ||
206*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_FLOAT16                 ||
207*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_FLOAT32                 ||
208*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_QUANT8_ASYMM            ||
209*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_QUANT8_ASYMM_SIGNED     ||
210*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_QUANT8_SYMM             ||
211*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL ||
212*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_QUANT16_SYMM            ||
213*89c4ff92SAndroid Build Coastguard Worker            type == OperandType::TENSOR_INT32;
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker 
IsBool(Operand operand)216*89c4ff92SAndroid Build Coastguard Worker inline bool IsBool(Operand operand)
217*89c4ff92SAndroid Build Coastguard Worker {
218*89c4ff92SAndroid Build Coastguard Worker     return operand.type == OperandType::BOOL;
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker 
Is12OrLaterOperand(Operand)221*89c4ff92SAndroid Build Coastguard Worker inline bool Is12OrLaterOperand(Operand)
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker     return true;
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker 
226*89c4ff92SAndroid Build Coastguard Worker 
227*89c4ff92SAndroid Build Coastguard Worker template<typename LayerHandleType>
AddReshapeLayer(armnn::INetwork & network,LayerHandleType & inputLayer,armnn::TensorInfo reshapeInfo)228*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& AddReshapeLayer(armnn::INetwork& network,
229*89c4ff92SAndroid Build Coastguard Worker                                           LayerHandleType& inputLayer,
230*89c4ff92SAndroid Build Coastguard Worker                                           armnn::TensorInfo reshapeInfo)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker     armnn::ReshapeDescriptor reshapeDescriptor;
233*89c4ff92SAndroid Build Coastguard Worker     reshapeDescriptor.m_TargetShape = reshapeInfo.GetShape();
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* reshapeLayer = network.AddReshapeLayer(reshapeDescriptor);
236*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(reshapeLayer != nullptr);
237*89c4ff92SAndroid Build Coastguard Worker 
238*89c4ff92SAndroid Build Coastguard Worker     // Attach the input layer to the reshape layer
239*89c4ff92SAndroid Build Coastguard Worker     inputLayer.Connect(reshapeLayer->GetInputSlot(0));
240*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer->GetOutputSlot(0).SetTensorInfo(reshapeInfo);
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker     return *reshapeLayer;
243*89c4ff92SAndroid Build Coastguard Worker }
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker 
// Flattens a FullyConnected input of rank > 2 into the 2D shape [batchSize, inputSize], where
// inputSize = weightsShape[1] and batchSize = (number of input elements) / inputSize.
// Inputs of rank <= 2 are returned unchanged.
//
// @param inputShape   Shape of the FullyConnected input.
// @param weightsShape Shape of the weights; dimension 1 gives the input size.
// @return The flattened (or original) input shape.
// @throws std::runtime_error("Failed to deduce tensor shape") if weightsShape[1] is 0 or the element
//         count of inputShape is not an exact multiple of weightsShape[1].
inline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape,
                                                     const armnn::TensorShape& weightsShape)
{
    if (inputShape.GetNumDimensions() <= 2U)
    {
        return inputShape;
    }

    const unsigned int totalInputElements = inputShape.GetNumElements();
    const unsigned int inputSize = weightsShape[1];

    // The flattened shape must hold exactly as many elements as the input, so the element count has to
    // divide evenly by inputSize. Testing divisibility by the derived batch size instead would accept
    // e.g. 10 elements with inputSize 4 (giving [2, 4] = 8 elements). It would also take a modulo by
    // zero when totalInputElements < inputSize.
    if (inputSize == 0 || totalInputElements % inputSize != 0)
    {
        throw std::runtime_error("Failed to deduce tensor shape");
    }

    const unsigned int batchSize = totalInputElements / inputSize;
    return armnn::TensorShape({batchSize, inputSize});
}
268*89c4ff92SAndroid Build Coastguard Worker 
VerifyFullyConnectedShapes(const armnn::TensorShape & inputShape,const armnn::TensorShape & weightsShape,const armnn::TensorShape & outputShape,bool transposeWeightMatrix)269*89c4ff92SAndroid Build Coastguard Worker inline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape,
270*89c4ff92SAndroid Build Coastguard Worker                                        const armnn::TensorShape& weightsShape,
271*89c4ff92SAndroid Build Coastguard Worker                                        const armnn::TensorShape& outputShape,
272*89c4ff92SAndroid Build Coastguard Worker                                        bool  transposeWeightMatrix)
273*89c4ff92SAndroid Build Coastguard Worker {
274*89c4ff92SAndroid Build Coastguard Worker     unsigned int dimIdx = transposeWeightMatrix ? 0 : 1;
275*89c4ff92SAndroid Build Coastguard Worker     return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]);
276*89c4ff92SAndroid Build Coastguard Worker }
277*89c4ff92SAndroid Build Coastguard Worker 
BroadcastTensor(LayerInputHandle & input0,LayerInputHandle & input1,armnn::IConnectableLayer * startLayer,ConversionData & data)278*89c4ff92SAndroid Build Coastguard Worker bool BroadcastTensor(LayerInputHandle& input0,
279*89c4ff92SAndroid Build Coastguard Worker                      LayerInputHandle& input1,
280*89c4ff92SAndroid Build Coastguard Worker                      armnn::IConnectableLayer* startLayer,
281*89c4ff92SAndroid Build Coastguard Worker                      ConversionData& data)
282*89c4ff92SAndroid Build Coastguard Worker {
283*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(startLayer != nullptr);
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputInfo0 = input0.GetTensorInfo();
286*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& inputInfo1 = input1.GetTensorInfo();
287*89c4ff92SAndroid Build Coastguard Worker 
288*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputDimensions0 = inputInfo0.GetNumDimensions();
289*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputDimensions1 = inputInfo1.GetNumDimensions();
290*89c4ff92SAndroid Build Coastguard Worker 
291*89c4ff92SAndroid Build Coastguard Worker     if (inputDimensions0 == inputDimensions1)
292*89c4ff92SAndroid Build Coastguard Worker     {
293*89c4ff92SAndroid Build Coastguard Worker         // The inputs have the same number of dimensions, simply connect them to the given layer as they are
294*89c4ff92SAndroid Build Coastguard Worker         input0.Connect(startLayer->GetInputSlot(0));
295*89c4ff92SAndroid Build Coastguard Worker         input1.Connect(startLayer->GetInputSlot(1));
296*89c4ff92SAndroid Build Coastguard Worker 
297*89c4ff92SAndroid Build Coastguard Worker         return true;
298*89c4ff92SAndroid Build Coastguard Worker     }
299*89c4ff92SAndroid Build Coastguard Worker 
300*89c4ff92SAndroid Build Coastguard Worker     // Since the number of dimensions do not match then we need to add degenerate dimensions
301*89c4ff92SAndroid Build Coastguard Worker     // to the "smaller" tensor using a reshape, while keeping the order of the inputs.
302*89c4ff92SAndroid Build Coastguard Worker 
303*89c4ff92SAndroid Build Coastguard Worker     unsigned int maxInputDimensions = std::max(inputDimensions0, inputDimensions1);
304*89c4ff92SAndroid Build Coastguard Worker     unsigned int sizeDifference = std::abs(armnn::numeric_cast<int>(inputDimensions0) -
305*89c4ff92SAndroid Build Coastguard Worker                                            armnn::numeric_cast<int>(inputDimensions1));
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker     bool input0IsSmaller = inputDimensions0 < inputDimensions1;
308*89c4ff92SAndroid Build Coastguard Worker     LayerInputHandle& smallInputHandle = input0IsSmaller ? input0 : input1;
309*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo& smallInfo = smallInputHandle.GetTensorInfo();
310*89c4ff92SAndroid Build Coastguard Worker 
311*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape& smallShape = smallInfo.GetShape();
312*89c4ff92SAndroid Build Coastguard Worker     std::vector<unsigned int> reshapedDimensions(maxInputDimensions, 1);
313*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = sizeDifference; i < maxInputDimensions; i++)
314*89c4ff92SAndroid Build Coastguard Worker     {
315*89c4ff92SAndroid Build Coastguard Worker         reshapedDimensions[i] = smallShape[i - sizeDifference];
316*89c4ff92SAndroid Build Coastguard Worker     }
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo reshapedInfo = smallInfo;
319*89c4ff92SAndroid Build Coastguard Worker     reshapedInfo.SetShape(armnn::TensorShape{ armnn::numeric_cast<unsigned int>(reshapedDimensions.size()),
320*89c4ff92SAndroid Build Coastguard Worker                                               reshapedDimensions.data() });
321*89c4ff92SAndroid Build Coastguard Worker 
322*89c4ff92SAndroid Build Coastguard Worker     // RehsapeDescriptor that is ignored in the IsReshapeSupported function
323*89c4ff92SAndroid Build Coastguard Worker     armnn::ReshapeDescriptor reshapeDescriptor;
324*89c4ff92SAndroid Build Coastguard Worker 
325*89c4ff92SAndroid Build Coastguard Worker     bool isSupported = false;
326*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendId setBackend;
327*89c4ff92SAndroid Build Coastguard Worker     FORWARD_LAYER_SUPPORT_FUNC(__func__,
328*89c4ff92SAndroid Build Coastguard Worker                                IsReshapeSupported,
329*89c4ff92SAndroid Build Coastguard Worker                                data.m_Backends,
330*89c4ff92SAndroid Build Coastguard Worker                                isSupported,
331*89c4ff92SAndroid Build Coastguard Worker                                setBackend,
332*89c4ff92SAndroid Build Coastguard Worker                                smallInfo,
333*89c4ff92SAndroid Build Coastguard Worker                                reshapedInfo,
334*89c4ff92SAndroid Build Coastguard Worker                                reshapeDescriptor);
335*89c4ff92SAndroid Build Coastguard Worker     if (!isSupported)
336*89c4ff92SAndroid Build Coastguard Worker     {
337*89c4ff92SAndroid Build Coastguard Worker         return false;
338*89c4ff92SAndroid Build Coastguard Worker     }
339*89c4ff92SAndroid Build Coastguard Worker 
340*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(data.m_Network != nullptr);
341*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer& reshapeLayer = AddReshapeLayer(*data.m_Network, smallInputHandle, reshapedInfo);
342*89c4ff92SAndroid Build Coastguard Worker     reshapeLayer.SetBackendId(setBackend);
343*89c4ff92SAndroid Build Coastguard Worker 
344*89c4ff92SAndroid Build Coastguard Worker     if (input0IsSmaller)
345*89c4ff92SAndroid Build Coastguard Worker     {
346*89c4ff92SAndroid Build Coastguard Worker         // Input0 is the "smaller" tensor, connect the reshape layer as follows:
347*89c4ff92SAndroid Build Coastguard Worker         //
348*89c4ff92SAndroid Build Coastguard Worker         //  Input0 Input1
349*89c4ff92SAndroid Build Coastguard Worker         //     |     |
350*89c4ff92SAndroid Build Coastguard Worker         //  Reshape  |
351*89c4ff92SAndroid Build Coastguard Worker         //      \   /
352*89c4ff92SAndroid Build Coastguard Worker         //    StartLayer
353*89c4ff92SAndroid Build Coastguard Worker 
354*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer.GetOutputSlot(0).Connect(startLayer->GetInputSlot(0));
355*89c4ff92SAndroid Build Coastguard Worker         input1.Connect(startLayer->GetInputSlot(1));
356*89c4ff92SAndroid Build Coastguard Worker     }
357*89c4ff92SAndroid Build Coastguard Worker     else
358*89c4ff92SAndroid Build Coastguard Worker     {
359*89c4ff92SAndroid Build Coastguard Worker         // Input1 is the "smaller" tensor, connect the reshape layer as follows:
360*89c4ff92SAndroid Build Coastguard Worker         //
361*89c4ff92SAndroid Build Coastguard Worker         //  Input0 Input1
362*89c4ff92SAndroid Build Coastguard Worker         //     |     |
363*89c4ff92SAndroid Build Coastguard Worker         //     |  Reshape
364*89c4ff92SAndroid Build Coastguard Worker         //      \   /
365*89c4ff92SAndroid Build Coastguard Worker         //    StartLayer
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker         input0.Connect(startLayer->GetInputSlot(0));
368*89c4ff92SAndroid Build Coastguard Worker         reshapeLayer.GetOutputSlot(0).Connect(startLayer->GetInputSlot(1));
369*89c4ff92SAndroid Build Coastguard Worker     }
370*89c4ff92SAndroid Build Coastguard Worker 
371*89c4ff92SAndroid Build Coastguard Worker     return true;
372*89c4ff92SAndroid Build Coastguard Worker }
373*89c4ff92SAndroid Build Coastguard Worker 
CalcPadding(uint32_t input,uint32_t kernel,uint32_t stride,uint32_t & outPadHead,uint32_t & outPadTail,PaddingScheme scheme)374*89c4ff92SAndroid Build Coastguard Worker void CalcPadding(uint32_t input,
375*89c4ff92SAndroid Build Coastguard Worker                  uint32_t kernel,
376*89c4ff92SAndroid Build Coastguard Worker                  uint32_t stride,
377*89c4ff92SAndroid Build Coastguard Worker                  uint32_t& outPadHead,
378*89c4ff92SAndroid Build Coastguard Worker                  uint32_t& outPadTail,
379*89c4ff92SAndroid Build Coastguard Worker                  PaddingScheme scheme)
380*89c4ff92SAndroid Build Coastguard Worker {
381*89c4ff92SAndroid Build Coastguard Worker     int32_t padHead;
382*89c4ff92SAndroid Build Coastguard Worker     int32_t padTail;
383*89c4ff92SAndroid Build Coastguard Worker     calculateExplicitPadding(input, stride, kernel, scheme, &padHead, &padTail);
384*89c4ff92SAndroid Build Coastguard Worker     outPadHead = armnn::numeric_cast<uint32_t>(padHead);
385*89c4ff92SAndroid Build Coastguard Worker     outPadTail = armnn::numeric_cast<uint32_t>(padTail);
386*89c4ff92SAndroid Build Coastguard Worker }
387*89c4ff92SAndroid Build Coastguard Worker 
CalcPadding(uint32_t input,uint32_t kernel,uint32_t stride,uint32_t dilation,uint32_t & outPadHead,uint32_t & outPadTail,::android::nn::PaddingScheme scheme)388*89c4ff92SAndroid Build Coastguard Worker void CalcPadding(uint32_t input, uint32_t kernel, uint32_t stride, uint32_t dilation, uint32_t& outPadHead,
389*89c4ff92SAndroid Build Coastguard Worker                  uint32_t& outPadTail, ::android::nn::PaddingScheme scheme)
390*89c4ff92SAndroid Build Coastguard Worker {
391*89c4ff92SAndroid Build Coastguard Worker     int32_t padHead;
392*89c4ff92SAndroid Build Coastguard Worker     int32_t padTail;
393*89c4ff92SAndroid Build Coastguard Worker     calculateExplicitPadding(input, stride, dilation, kernel, scheme, &padHead, &padTail);
394*89c4ff92SAndroid Build Coastguard Worker     outPadHead = armnn::numeric_cast<uint32_t>(padHead);
395*89c4ff92SAndroid Build Coastguard Worker     outPadTail = armnn::numeric_cast<uint32_t>(padTail);
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker 
CalcPaddingTransposeConv(uint32_t output,uint32_t kernel,int32_t stride,int32_t & outPadHead,int32_t & outPadTail,::android::nn::PaddingScheme scheme)398*89c4ff92SAndroid Build Coastguard Worker inline void CalcPaddingTransposeConv(uint32_t output, uint32_t kernel, int32_t stride, int32_t& outPadHead,
399*89c4ff92SAndroid Build Coastguard Worker                               int32_t& outPadTail, ::android::nn::PaddingScheme scheme)
400*89c4ff92SAndroid Build Coastguard Worker {
401*89c4ff92SAndroid Build Coastguard Worker     calculateExplicitPaddingTransposeConv(output, stride, kernel, scheme, &outPadHead, &outPadTail);
402*89c4ff92SAndroid Build Coastguard Worker }
403*89c4ff92SAndroid Build Coastguard Worker 
GetOperandShape(const Operand & operand)404*89c4ff92SAndroid Build Coastguard Worker Shape GetOperandShape(const Operand& operand)
405*89c4ff92SAndroid Build Coastguard Worker {
406*89c4ff92SAndroid Build Coastguard Worker     Shape shape;
407*89c4ff92SAndroid Build Coastguard Worker     shape.type = OperandType(operand.type);
408*89c4ff92SAndroid Build Coastguard Worker     shape.dimensions = operand.dimensions;
409*89c4ff92SAndroid Build Coastguard Worker     shape.scale = operand.scale;
410*89c4ff92SAndroid Build Coastguard Worker     shape.offset = operand.zeroPoint;
411*89c4ff92SAndroid Build Coastguard Worker     return shape;
412*89c4ff92SAndroid Build Coastguard Worker }
413*89c4ff92SAndroid Build Coastguard Worker 
414*89c4ff92SAndroid Build Coastguard Worker 
415*89c4ff92SAndroid Build Coastguard Worker // ArmNN requires the bias scale to be equal to the product of the weight and input scales, which is also
416*89c4ff92SAndroid Build Coastguard Worker // what AndroidNN requires. However for some of the AndroidNN tests the values don't exactly match so
417*89c4ff92SAndroid Build Coastguard Worker // we accept some tolerance. We don't want ArmNN itself to accept these inconsistencies as it is up to the
418*89c4ff92SAndroid Build Coastguard Worker // user (us, in this case) to ensure they match.
SanitizeBiasQuantizationScale(armnn::TensorInfo & biasInfo,const armnn::TensorInfo & weightInfo,const armnn::TensorInfo & inputInfo)419*89c4ff92SAndroid Build Coastguard Worker void SanitizeBiasQuantizationScale(armnn::TensorInfo& biasInfo,
420*89c4ff92SAndroid Build Coastguard Worker                                    const armnn::TensorInfo& weightInfo,
421*89c4ff92SAndroid Build Coastguard Worker                                    const armnn::TensorInfo& inputInfo)
422*89c4ff92SAndroid Build Coastguard Worker {
423*89c4ff92SAndroid Build Coastguard Worker     if (weightInfo.HasPerAxisQuantization())
424*89c4ff92SAndroid Build Coastguard Worker     {
425*89c4ff92SAndroid Build Coastguard Worker         // NOTE: Bias scale is always set to 0 for per-axis quantization and
426*89c4ff92SAndroid Build Coastguard Worker         // it needs to be calculated: scale[i] = input_scale * weight_scale[i]
427*89c4ff92SAndroid Build Coastguard Worker         auto UpdateBiasScaleValue = [&inputInfo](float biasScale) -> float
428*89c4ff92SAndroid Build Coastguard Worker         {
429*89c4ff92SAndroid Build Coastguard Worker             return biasScale * inputInfo.GetQuantizationScale();
430*89c4ff92SAndroid Build Coastguard Worker         };
431*89c4ff92SAndroid Build Coastguard Worker 
432*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> biasScales(weightInfo.GetQuantizationScales());
433*89c4ff92SAndroid Build Coastguard Worker         std::transform(biasScales.begin(), biasScales.end(), biasScales.begin(), UpdateBiasScaleValue);
434*89c4ff92SAndroid Build Coastguard Worker 
435*89c4ff92SAndroid Build Coastguard Worker         biasInfo.SetQuantizationScales(biasScales);
436*89c4ff92SAndroid Build Coastguard Worker         // bias is expected to be a 1d tensor, set qdim=0
437*89c4ff92SAndroid Build Coastguard Worker         biasInfo.SetQuantizationDim(0);
438*89c4ff92SAndroid Build Coastguard Worker 
439*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << "Bias quantization params have been updated for per-axis quantization";
440*89c4ff92SAndroid Build Coastguard Worker     }
441*89c4ff92SAndroid Build Coastguard Worker     else
442*89c4ff92SAndroid Build Coastguard Worker     {
443*89c4ff92SAndroid Build Coastguard Worker         const float expectedBiasScale = weightInfo.GetQuantizationScale() * inputInfo.GetQuantizationScale();
444*89c4ff92SAndroid Build Coastguard Worker         if (biasInfo.GetQuantizationScale() != expectedBiasScale)
445*89c4ff92SAndroid Build Coastguard Worker         {
446*89c4ff92SAndroid Build Coastguard Worker             if (armnnUtils::within_percentage_tolerance(biasInfo.GetQuantizationScale(), expectedBiasScale, 1.0f))
447*89c4ff92SAndroid Build Coastguard Worker             {
448*89c4ff92SAndroid Build Coastguard Worker                 VLOG(DRIVER) << "Bias quantization scale has been modified to match input * weights";
449*89c4ff92SAndroid Build Coastguard Worker                 biasInfo.SetQuantizationScale(expectedBiasScale);
450*89c4ff92SAndroid Build Coastguard Worker             }
451*89c4ff92SAndroid Build Coastguard Worker         }
452*89c4ff92SAndroid Build Coastguard Worker     }
453*89c4ff92SAndroid Build Coastguard Worker }
454*89c4ff92SAndroid Build Coastguard Worker 
455*89c4ff92SAndroid Build Coastguard Worker // 4D Tensor Permutations
456*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector IdentityPermutation4D({ 0U, 1U, 2U, 3U });
457*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector IdentityPermutation3D({ 0U, 1U, 2U });
458*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector SwapDim2And3({ 0U, 1U, 3U, 2U });
459*89c4ff92SAndroid Build Coastguard Worker 
460*89c4ff92SAndroid Build Coastguard Worker // 3D Permutation Vectors
461*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector RotateTensorLeft({ 1U, 2U, 0U });
462*89c4ff92SAndroid Build Coastguard Worker const armnn::PermutationVector RotateTensorRight({ 2U, 0U, 1U });
463*89c4ff92SAndroid Build Coastguard Worker 
464*89c4ff92SAndroid Build Coastguard Worker template<typename OSlot>
AddTransposeLayer(armnn::INetwork & network,OSlot & input,const armnn::PermutationVector & mappings)465*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer& AddTransposeLayer(armnn::INetwork& network, OSlot& input,
466*89c4ff92SAndroid Build Coastguard Worker                                             const armnn::PermutationVector& mappings)
467*89c4ff92SAndroid Build Coastguard Worker {
468*89c4ff92SAndroid Build Coastguard Worker     // Add swizzle layer
469*89c4ff92SAndroid Build Coastguard Worker     armnn::IConnectableLayer* const layer = network.AddTransposeLayer(mappings);
470*89c4ff92SAndroid Build Coastguard Worker 
471*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(layer != nullptr);
472*89c4ff92SAndroid Build Coastguard Worker 
473*89c4ff92SAndroid Build Coastguard Worker     // Connect input to swizzle layer
474*89c4ff92SAndroid Build Coastguard Worker     input.Connect(layer->GetInputSlot(0));
475*89c4ff92SAndroid Build Coastguard Worker 
476*89c4ff92SAndroid Build Coastguard Worker     // Setup swizzled output
477*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo outInfo = armnnUtils::TransposeTensorShape(input.GetTensorInfo(), mappings);
478*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(outInfo);
479*89c4ff92SAndroid Build Coastguard Worker 
480*89c4ff92SAndroid Build Coastguard Worker     return *layer;
481*89c4ff92SAndroid Build Coastguard Worker }
482*89c4ff92SAndroid Build Coastguard Worker 
ValidateConcatOutputShape(const std::vector<armnn::TensorShape> & inputShapes,const armnn::TensorShape & outputShape,uint32_t concatDim)483*89c4ff92SAndroid Build Coastguard Worker bool ValidateConcatOutputShape(const std::vector<armnn::TensorShape> & inputShapes,
484*89c4ff92SAndroid Build Coastguard Worker                                const armnn::TensorShape & outputShape,
485*89c4ff92SAndroid Build Coastguard Worker                                uint32_t concatDim)
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker     // Validate the output shape is correct given the input shapes (which have just been validated)
488*89c4ff92SAndroid Build Coastguard Worker     unsigned int numDimensions = inputShapes[0].GetNumDimensions();
489*89c4ff92SAndroid Build Coastguard Worker     if (outputShape.GetNumDimensions() != numDimensions)
490*89c4ff92SAndroid Build Coastguard Worker     {
491*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: Output shape has wrong number of dimensions", __func__);
492*89c4ff92SAndroid Build Coastguard Worker     }
493*89c4ff92SAndroid Build Coastguard Worker 
494*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputSizeAlongConcatenatedDimension = 0;
495*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < inputShapes.size(); i++)
496*89c4ff92SAndroid Build Coastguard Worker     {
497*89c4ff92SAndroid Build Coastguard Worker         outputSizeAlongConcatenatedDimension += inputShapes[i][concatDim];
498*89c4ff92SAndroid Build Coastguard Worker     }
499*89c4ff92SAndroid Build Coastguard Worker 
500*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < numDimensions; ++i)
501*89c4ff92SAndroid Build Coastguard Worker     {
502*89c4ff92SAndroid Build Coastguard Worker         if (i == concatDim)
503*89c4ff92SAndroid Build Coastguard Worker         {
504*89c4ff92SAndroid Build Coastguard Worker             if (outputShape[i] != outputSizeAlongConcatenatedDimension)
505*89c4ff92SAndroid Build Coastguard Worker             {
506*89c4ff92SAndroid Build Coastguard Worker                 return Fail(
507*89c4ff92SAndroid Build Coastguard Worker                         "%s: Invalid output shape for dimension %d (%d != %d)",
508*89c4ff92SAndroid Build Coastguard Worker                         __func__,
509*89c4ff92SAndroid Build Coastguard Worker                         i,
510*89c4ff92SAndroid Build Coastguard Worker                         outputShape[i],
511*89c4ff92SAndroid Build Coastguard Worker                         outputSizeAlongConcatenatedDimension);
512*89c4ff92SAndroid Build Coastguard Worker             }
513*89c4ff92SAndroid Build Coastguard Worker         }
514*89c4ff92SAndroid Build Coastguard Worker         else
515*89c4ff92SAndroid Build Coastguard Worker         {
516*89c4ff92SAndroid Build Coastguard Worker             if (outputShape[i] != inputShapes[0][i])
517*89c4ff92SAndroid Build Coastguard Worker             {
518*89c4ff92SAndroid Build Coastguard Worker                 return Fail("%s: Invalid output shape", __func__);
519*89c4ff92SAndroid Build Coastguard Worker             }
520*89c4ff92SAndroid Build Coastguard Worker         }
521*89c4ff92SAndroid Build Coastguard Worker     }
522*89c4ff92SAndroid Build Coastguard Worker 
523*89c4ff92SAndroid Build Coastguard Worker     return true;
524*89c4ff92SAndroid Build Coastguard Worker }
525*89c4ff92SAndroid Build Coastguard Worker 
RequiresReshape(armnn::TensorShape & inputShape)526*89c4ff92SAndroid Build Coastguard Worker inline bool RequiresReshape(armnn::TensorShape & inputShape)
527*89c4ff92SAndroid Build Coastguard Worker {
528*89c4ff92SAndroid Build Coastguard Worker     return inputShape.GetNumDimensions() < 3;
529*89c4ff92SAndroid Build Coastguard Worker }
530*89c4ff92SAndroid Build Coastguard Worker 
SwizzleInputs(armnn::INetwork & network,std::vector<LayerInputHandle> & inputs,std::vector<armnn::TensorShape> & inputShapes,const armnn::PermutationVector & mapping,std::vector<armnn::BackendId> & setBackends)531*89c4ff92SAndroid Build Coastguard Worker inline void SwizzleInputs(armnn::INetwork& network,
532*89c4ff92SAndroid Build Coastguard Worker                    std::vector<LayerInputHandle>& inputs,
533*89c4ff92SAndroid Build Coastguard Worker                    std::vector<armnn::TensorShape>& inputShapes,
534*89c4ff92SAndroid Build Coastguard Worker                    const armnn::PermutationVector& mapping,
535*89c4ff92SAndroid Build Coastguard Worker                    std::vector<armnn::BackendId>& setBackends)
536*89c4ff92SAndroid Build Coastguard Worker {
537*89c4ff92SAndroid Build Coastguard Worker     if (!mapping.IsEqual(IdentityPermutation4D))
538*89c4ff92SAndroid Build Coastguard Worker     {
539*89c4ff92SAndroid Build Coastguard Worker         size_t nInputs = inputs.size();
540*89c4ff92SAndroid Build Coastguard Worker         for (size_t i=0; i<nInputs; ++i)
541*89c4ff92SAndroid Build Coastguard Worker         {
542*89c4ff92SAndroid Build Coastguard Worker             // add swizzle layer
543*89c4ff92SAndroid Build Coastguard Worker             armnn::IConnectableLayer& swizzleLayer = AddTransposeLayer(network, inputs[i], mapping);
544*89c4ff92SAndroid Build Coastguard Worker             swizzleLayer.SetBackendId(setBackends[i]);
545*89c4ff92SAndroid Build Coastguard Worker             auto& outputSlot = swizzleLayer.GetOutputSlot(0);
546*89c4ff92SAndroid Build Coastguard Worker             auto& outputInfo = outputSlot.GetTensorInfo();
547*89c4ff92SAndroid Build Coastguard Worker             // replace inputs with the swizzled ones
548*89c4ff92SAndroid Build Coastguard Worker             inputs[i] = LayerInputHandle(true, &outputSlot, outputInfo);
549*89c4ff92SAndroid Build Coastguard Worker             inputShapes[i] = inputs[i].GetTensorInfo().GetShape();
550*89c4ff92SAndroid Build Coastguard Worker         }
551*89c4ff92SAndroid Build Coastguard Worker     }
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker 
TransposeInputTensors(ConversionData & data,std::vector<LayerInputHandle> & inputs,std::vector<armnn::TensorShape> & inputShapes,const armnn::PermutationVector & mapping)554*89c4ff92SAndroid Build Coastguard Worker bool TransposeInputTensors(ConversionData& data,
555*89c4ff92SAndroid Build Coastguard Worker                           std::vector<LayerInputHandle>& inputs,
556*89c4ff92SAndroid Build Coastguard Worker                           std::vector<armnn::TensorShape>& inputShapes,
557*89c4ff92SAndroid Build Coastguard Worker                           const armnn::PermutationVector& mapping)
558*89c4ff92SAndroid Build Coastguard Worker {
559*89c4ff92SAndroid Build Coastguard Worker     // If we have a IdentityPermutation4D or IdentityPermutation3D then we are not permuting
560*89c4ff92SAndroid Build Coastguard Worker     if (!mapping.IsEqual(IdentityPermutation4D) && !mapping.IsEqual(IdentityPermutation3D))
561*89c4ff92SAndroid Build Coastguard Worker     {
562*89c4ff92SAndroid Build Coastguard Worker         std::vector<armnn::BackendId> setBackendsVec;
563*89c4ff92SAndroid Build Coastguard Worker         armnn::TensorInfo outputTransposeInfo;
564*89c4ff92SAndroid Build Coastguard Worker         size_t nInputs = inputs.size();
565*89c4ff92SAndroid Build Coastguard Worker         for (size_t i=0; i<nInputs; ++i)
566*89c4ff92SAndroid Build Coastguard Worker         {
567*89c4ff92SAndroid Build Coastguard Worker             // check permute layer
568*89c4ff92SAndroid Build Coastguard Worker             armnn::TransposeDescriptor transposeDesc;
569*89c4ff92SAndroid Build Coastguard Worker             transposeDesc.m_DimMappings = mapping;
570*89c4ff92SAndroid Build Coastguard Worker             outputTransposeInfo = armnnUtils::TransposeTensorShape(inputs[i].GetTensorInfo(), mapping);
571*89c4ff92SAndroid Build Coastguard Worker 
572*89c4ff92SAndroid Build Coastguard Worker             bool isSupported = false;
573*89c4ff92SAndroid Build Coastguard Worker             armnn::BackendId setBackend;
574*89c4ff92SAndroid Build Coastguard Worker             FORWARD_LAYER_SUPPORT_FUNC(__func__,
575*89c4ff92SAndroid Build Coastguard Worker                                        IsTransposeSupported,
576*89c4ff92SAndroid Build Coastguard Worker                                        data.m_Backends,
577*89c4ff92SAndroid Build Coastguard Worker                                        isSupported,
578*89c4ff92SAndroid Build Coastguard Worker                                        setBackend,
579*89c4ff92SAndroid Build Coastguard Worker                                        inputs[i].GetTensorInfo(),
580*89c4ff92SAndroid Build Coastguard Worker                                        outputTransposeInfo,
581*89c4ff92SAndroid Build Coastguard Worker                                        transposeDesc);
582*89c4ff92SAndroid Build Coastguard Worker             setBackendsVec.push_back(setBackend);
583*89c4ff92SAndroid Build Coastguard Worker             if (!isSupported)
584*89c4ff92SAndroid Build Coastguard Worker             {
585*89c4ff92SAndroid Build Coastguard Worker                 return false;
586*89c4ff92SAndroid Build Coastguard Worker             }
587*89c4ff92SAndroid Build Coastguard Worker 
588*89c4ff92SAndroid Build Coastguard Worker         }
589*89c4ff92SAndroid Build Coastguard Worker         SwizzleInputs(*data.m_Network, inputs, inputShapes, mapping, setBackendsVec);
590*89c4ff92SAndroid Build Coastguard Worker     }
591*89c4ff92SAndroid Build Coastguard Worker     return true;
592*89c4ff92SAndroid Build Coastguard Worker }
593*89c4ff92SAndroid Build Coastguard Worker 
CreateConcatPermutationParameters(const unsigned int numberOfDimensions,int32_t & concatDimension,std::pair<armnn::PermutationVector,armnn::PermutationVector> & permutationPair)594*89c4ff92SAndroid Build Coastguard Worker bool CreateConcatPermutationParameters(const unsigned int numberOfDimensions,
595*89c4ff92SAndroid Build Coastguard Worker                                        int32_t & concatDimension,
596*89c4ff92SAndroid Build Coastguard Worker                                        std::pair<armnn::PermutationVector, armnn::PermutationVector> & permutationPair)
597*89c4ff92SAndroid Build Coastguard Worker {
598*89c4ff92SAndroid Build Coastguard Worker     bool needPermute = false;
599*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(numberOfDimensions >= 3);
600*89c4ff92SAndroid Build Coastguard Worker 
601*89c4ff92SAndroid Build Coastguard Worker     // ArmNN uses Compute Library subtensors to perform concatenation
602*89c4ff92SAndroid Build Coastguard Worker     // This only works when concatenating along dimension 0, 1 or 3 for a 4-D tensor,
603*89c4ff92SAndroid Build Coastguard Worker     // or along dimension 0 or 2 for a 3-D tensor.
604*89c4ff92SAndroid Build Coastguard Worker     if (numberOfDimensions == 4 && concatDimension == 2)
605*89c4ff92SAndroid Build Coastguard Worker     {
606*89c4ff92SAndroid Build Coastguard Worker         concatDimension = 3;
607*89c4ff92SAndroid Build Coastguard Worker         permutationPair = std::make_pair(SwapDim2And3, SwapDim2And3);
608*89c4ff92SAndroid Build Coastguard Worker         needPermute = true;
609*89c4ff92SAndroid Build Coastguard Worker     }
610*89c4ff92SAndroid Build Coastguard Worker     else if (numberOfDimensions == 3 && concatDimension == 1)
611*89c4ff92SAndroid Build Coastguard Worker     {
612*89c4ff92SAndroid Build Coastguard Worker         concatDimension = 0;
613*89c4ff92SAndroid Build Coastguard Worker         permutationPair = std::make_pair(RotateTensorLeft, RotateTensorRight);
614*89c4ff92SAndroid Build Coastguard Worker         needPermute = true;
615*89c4ff92SAndroid Build Coastguard Worker     }
616*89c4ff92SAndroid Build Coastguard Worker     // If the tensor is 3-D and the concat dimension is 2 then we don't need to permute but we do need to change the
617*89c4ff92SAndroid Build Coastguard Worker     // permutation identity to only have 3 dimensions
618*89c4ff92SAndroid Build Coastguard Worker     else if (numberOfDimensions == 3 && concatDimension == 2)
619*89c4ff92SAndroid Build Coastguard Worker     {
620*89c4ff92SAndroid Build Coastguard Worker         permutationPair = std::make_pair(IdentityPermutation3D, IdentityPermutation3D);
621*89c4ff92SAndroid Build Coastguard Worker     }
622*89c4ff92SAndroid Build Coastguard Worker     return needPermute;
623*89c4ff92SAndroid Build Coastguard Worker }
624*89c4ff92SAndroid Build Coastguard Worker 
625*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
626*89c4ff92SAndroid Build Coastguard Worker 
627*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
628*89c4ff92SAndroid Build Coastguard Worker {
629*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
630*89c4ff92SAndroid Build Coastguard Worker 
631*89c4ff92SAndroid Build Coastguard Worker //// Creates an ArmNN activation layer and connects it to the given layer, if the
632*89c4ff92SAndroid Build Coastguard Worker //// passed in AndroidNN activation function requires so.
633*89c4ff92SAndroid Build Coastguard Worker //// @return The end layer of the sequence of layers built for the given AndroidNN
634*89c4ff92SAndroid Build Coastguard Worker //// activation function or nullptr if an error occurred (e.g. unsupported activation).
635*89c4ff92SAndroid Build Coastguard Worker //// Note that the end layer matches the input layer if no activation is required
636*89c4ff92SAndroid Build Coastguard Worker //// (the sequence of layers has length 1).
637*89c4ff92SAndroid Build Coastguard Worker armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
638*89c4ff92SAndroid Build Coastguard Worker                                             ActivationFn activation,
639*89c4ff92SAndroid Build Coastguard Worker                                             armnn::IConnectableLayer* prevLayer,
640*89c4ff92SAndroid Build Coastguard Worker                                             ConversionData& data);
641*89c4ff92SAndroid Build Coastguard Worker 
642*89c4ff92SAndroid Build Coastguard Worker 
GetInputOperand(const Operation & operation,uint32_t inputIndex,const Model & model,bool failOnIndexOutOfBounds=true)643*89c4ff92SAndroid Build Coastguard Worker inline const Operand* GetInputOperand(const Operation& operation,
644*89c4ff92SAndroid Build Coastguard Worker                                       uint32_t inputIndex,
645*89c4ff92SAndroid Build Coastguard Worker                                       const Model& model,
646*89c4ff92SAndroid Build Coastguard Worker                                       bool failOnIndexOutOfBounds = true)
647*89c4ff92SAndroid Build Coastguard Worker {
648*89c4ff92SAndroid Build Coastguard Worker     if (inputIndex >= operation.inputs.size())
649*89c4ff92SAndroid Build Coastguard Worker     {
650*89c4ff92SAndroid Build Coastguard Worker         if (failOnIndexOutOfBounds)
651*89c4ff92SAndroid Build Coastguard Worker         {
652*89c4ff92SAndroid Build Coastguard Worker             Fail("%s: invalid input index: %i out of %i", __func__, inputIndex, operation.inputs.size());
653*89c4ff92SAndroid Build Coastguard Worker         }
654*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
655*89c4ff92SAndroid Build Coastguard Worker     }
656*89c4ff92SAndroid Build Coastguard Worker 
657*89c4ff92SAndroid Build Coastguard Worker     // Model should have been validated beforehand
658*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(operation.inputs[inputIndex] < getMainModel(model).operands.size());
659*89c4ff92SAndroid Build Coastguard Worker     return &getMainModel(model).operands[operation.inputs[inputIndex]];
660*89c4ff92SAndroid Build Coastguard Worker }
661*89c4ff92SAndroid Build Coastguard Worker 
GetOutputOperand(const Operation & operation,uint32_t outputIndex,const Model & model)662*89c4ff92SAndroid Build Coastguard Worker inline const Operand* GetOutputOperand(const Operation& operation,
663*89c4ff92SAndroid Build Coastguard Worker                                        uint32_t outputIndex,
664*89c4ff92SAndroid Build Coastguard Worker                                        const Model& model)
665*89c4ff92SAndroid Build Coastguard Worker {
666*89c4ff92SAndroid Build Coastguard Worker     if (outputIndex >= operation.outputs.size())
667*89c4ff92SAndroid Build Coastguard Worker     {
668*89c4ff92SAndroid Build Coastguard Worker         Fail("%s: invalid output index: %i out of %i", __func__, outputIndex, operation.outputs.size());
669*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
670*89c4ff92SAndroid Build Coastguard Worker     }
671*89c4ff92SAndroid Build Coastguard Worker 
672*89c4ff92SAndroid Build Coastguard Worker     // Model should have been validated beforehand
673*89c4ff92SAndroid Build Coastguard Worker     ARMNN_ASSERT(operation.outputs[outputIndex] < getMainModel(model).operands.size());
674*89c4ff92SAndroid Build Coastguard Worker 
675*89c4ff92SAndroid Build Coastguard Worker     return &getMainModel(model).operands[operation.outputs[outputIndex]];
676*89c4ff92SAndroid Build Coastguard Worker }
677*89c4ff92SAndroid Build Coastguard Worker 
678*89c4ff92SAndroid Build Coastguard Worker const void* GetOperandValueReadOnlyAddress(const Operand& operand,
679*89c4ff92SAndroid Build Coastguard Worker                                            const Model& model,
680*89c4ff92SAndroid Build Coastguard Worker                                            const ConversionData& data,
681*89c4ff92SAndroid Build Coastguard Worker                                            bool optional = false);
682*89c4ff92SAndroid Build Coastguard Worker 
GetOperandType(const Operation & operation,uint32_t inputIndex,const Model & model,OperandType & type)683*89c4ff92SAndroid Build Coastguard Worker inline bool GetOperandType(const Operation& operation,
684*89c4ff92SAndroid Build Coastguard Worker                            uint32_t inputIndex,
685*89c4ff92SAndroid Build Coastguard Worker                            const Model& model,
686*89c4ff92SAndroid Build Coastguard Worker                            OperandType& type)
687*89c4ff92SAndroid Build Coastguard Worker {
688*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
689*89c4ff92SAndroid Build Coastguard Worker     if (!operand)
690*89c4ff92SAndroid Build Coastguard Worker     {
691*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
692*89c4ff92SAndroid Build Coastguard Worker     }
693*89c4ff92SAndroid Build Coastguard Worker 
694*89c4ff92SAndroid Build Coastguard Worker     type = operand->type;
695*89c4ff92SAndroid Build Coastguard Worker     return true;
696*89c4ff92SAndroid Build Coastguard Worker }
697*89c4ff92SAndroid Build Coastguard Worker 
IsOperandConstant(const Operand & operand)698*89c4ff92SAndroid Build Coastguard Worker inline bool IsOperandConstant(const Operand& operand)
699*89c4ff92SAndroid Build Coastguard Worker {
700*89c4ff92SAndroid Build Coastguard Worker     OperandLifeTime lifetime = operand.lifetime;
701*89c4ff92SAndroid Build Coastguard Worker 
702*89c4ff92SAndroid Build Coastguard Worker     return lifetime == OperandLifeTime::CONSTANT_COPY ||
703*89c4ff92SAndroid Build Coastguard Worker            lifetime == OperandLifeTime::CONSTANT_REFERENCE ||
704*89c4ff92SAndroid Build Coastguard Worker            lifetime == OperandLifeTime::POINTER ||
705*89c4ff92SAndroid Build Coastguard Worker            lifetime == OperandLifeTime::NO_VALUE;
706*89c4ff92SAndroid Build Coastguard Worker }
707*89c4ff92SAndroid Build Coastguard Worker 
708*89c4ff92SAndroid Build Coastguard Worker bool IsWeightsValid(const Operation& operation, uint32_t inputIndex, const Model& model);
709*89c4ff92SAndroid Build Coastguard Worker 
710*89c4ff92SAndroid Build Coastguard Worker ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
711*89c4ff92SAndroid Build Coastguard Worker                                               const Model& model,
712*89c4ff92SAndroid Build Coastguard Worker                                               const ConversionData& data,
713*89c4ff92SAndroid Build Coastguard Worker                                               const armnn::PermutationVector& dimensionMappings = g_DontPermute,
714*89c4ff92SAndroid Build Coastguard Worker                                               const armnn::TensorShape* overrideTensorShape = nullptr,
715*89c4ff92SAndroid Build Coastguard Worker                                               bool optional = false,
716*89c4ff92SAndroid Build Coastguard Worker                                               const armnn::DataType* overrideDataType = nullptr);
717*89c4ff92SAndroid Build Coastguard Worker 
ConvertOperationInputToConstTensorPin(const Operation & operation,uint32_t inputIndex,const Model & model,const ConversionData & data,const armnn::PermutationVector & dimensionMappings=g_DontPermute,const armnn::TensorShape * overrideTensorShape=nullptr,bool optional=false)718*89c4ff92SAndroid Build Coastguard Worker inline ConstTensorPin ConvertOperationInputToConstTensorPin(
719*89c4ff92SAndroid Build Coastguard Worker         const Operation& operation,
720*89c4ff92SAndroid Build Coastguard Worker         uint32_t inputIndex,
721*89c4ff92SAndroid Build Coastguard Worker         const Model& model,
722*89c4ff92SAndroid Build Coastguard Worker         const ConversionData& data,
723*89c4ff92SAndroid Build Coastguard Worker         const armnn::PermutationVector& dimensionMappings = g_DontPermute,
724*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorShape* overrideTensorShape = nullptr,
725*89c4ff92SAndroid Build Coastguard Worker         bool optional = false)
726*89c4ff92SAndroid Build Coastguard Worker {
727*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
728*89c4ff92SAndroid Build Coastguard Worker     if (!operand)
729*89c4ff92SAndroid Build Coastguard Worker     {
730*89c4ff92SAndroid Build Coastguard Worker         Fail("%s: failed to get input operand: index=%u", __func__, inputIndex);
731*89c4ff92SAndroid Build Coastguard Worker         return ConstTensorPin();
732*89c4ff92SAndroid Build Coastguard Worker     }
733*89c4ff92SAndroid Build Coastguard Worker     return ConvertOperandToConstTensorPin(*operand,
734*89c4ff92SAndroid Build Coastguard Worker                                           model,
735*89c4ff92SAndroid Build Coastguard Worker                                           data,
736*89c4ff92SAndroid Build Coastguard Worker                                           dimensionMappings,
737*89c4ff92SAndroid Build Coastguard Worker                                           overrideTensorShape,
738*89c4ff92SAndroid Build Coastguard Worker                                           optional);
739*89c4ff92SAndroid Build Coastguard Worker }
740*89c4ff92SAndroid Build Coastguard Worker 
741*89c4ff92SAndroid Build Coastguard Worker template <typename OutputType>
GetInputScalar(const Operation & operation,uint32_t inputIndex,OperandType type,OutputType & outValue,const Model & model,const ConversionData & data,bool optional=false)742*89c4ff92SAndroid Build Coastguard Worker bool GetInputScalar(const Operation& operation,
743*89c4ff92SAndroid Build Coastguard Worker                     uint32_t inputIndex,
744*89c4ff92SAndroid Build Coastguard Worker                     OperandType type,
745*89c4ff92SAndroid Build Coastguard Worker                     OutputType& outValue,
746*89c4ff92SAndroid Build Coastguard Worker                     const Model& model,
747*89c4ff92SAndroid Build Coastguard Worker                     const ConversionData& data,
748*89c4ff92SAndroid Build Coastguard Worker                     bool optional = false)
749*89c4ff92SAndroid Build Coastguard Worker {
750*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
751*89c4ff92SAndroid Build Coastguard Worker     if (!optional && !operand)
752*89c4ff92SAndroid Build Coastguard Worker     {
753*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: invalid input operand at index %i", __func__, inputIndex);
754*89c4ff92SAndroid Build Coastguard Worker     }
755*89c4ff92SAndroid Build Coastguard Worker 
756*89c4ff92SAndroid Build Coastguard Worker     if (!optional && operand->type != type)
757*89c4ff92SAndroid Build Coastguard Worker     {
758*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": unexpected operand type: " << operand->type << " should be: " << type;
759*89c4ff92SAndroid Build Coastguard Worker         return false;
760*89c4ff92SAndroid Build Coastguard Worker     }
761*89c4ff92SAndroid Build Coastguard Worker 
762*89c4ff92SAndroid Build Coastguard Worker     if (!optional && operand->location.length != sizeof(OutputType))
763*89c4ff92SAndroid Build Coastguard Worker     {
764*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: incorrect operand location length: %i (should be %i)",
765*89c4ff92SAndroid Build Coastguard Worker                     __func__, operand->location.length, sizeof(OutputType));
766*89c4ff92SAndroid Build Coastguard Worker     }
767*89c4ff92SAndroid Build Coastguard Worker 
768*89c4ff92SAndroid Build Coastguard Worker     const void* valueAddress = GetOperandValueReadOnlyAddress(*operand, model, data);
769*89c4ff92SAndroid Build Coastguard Worker     if (!optional && !valueAddress)
770*89c4ff92SAndroid Build Coastguard Worker     {
771*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: failed to get address for operand", __func__);
772*89c4ff92SAndroid Build Coastguard Worker     }
773*89c4ff92SAndroid Build Coastguard Worker 
774*89c4ff92SAndroid Build Coastguard Worker     if(!optional)
775*89c4ff92SAndroid Build Coastguard Worker     {
776*89c4ff92SAndroid Build Coastguard Worker         outValue = *(static_cast<const OutputType*>(valueAddress));
777*89c4ff92SAndroid Build Coastguard Worker     }
778*89c4ff92SAndroid Build Coastguard Worker 
779*89c4ff92SAndroid Build Coastguard Worker     return true;
780*89c4ff92SAndroid Build Coastguard Worker }
781*89c4ff92SAndroid Build Coastguard Worker 
GetInputInt32(const Operation & operation,uint32_t inputIndex,int32_t & outValue,const Model & model,const ConversionData & data)782*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputInt32(const Operation& operation,
783*89c4ff92SAndroid Build Coastguard Worker                           uint32_t inputIndex,
784*89c4ff92SAndroid Build Coastguard Worker                           int32_t& outValue,
785*89c4ff92SAndroid Build Coastguard Worker                           const Model& model,
786*89c4ff92SAndroid Build Coastguard Worker                           const ConversionData& data)
787*89c4ff92SAndroid Build Coastguard Worker {
788*89c4ff92SAndroid Build Coastguard Worker     return GetInputScalar(operation, inputIndex, OperandType::INT32, outValue, model, data);
789*89c4ff92SAndroid Build Coastguard Worker }
790*89c4ff92SAndroid Build Coastguard Worker 
GetInputFloat32(const Operation & operation,uint32_t inputIndex,float & outValue,const Model & model,const ConversionData & data)791*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputFloat32(const Operation& operation,
792*89c4ff92SAndroid Build Coastguard Worker                             uint32_t inputIndex,
793*89c4ff92SAndroid Build Coastguard Worker                             float& outValue,
794*89c4ff92SAndroid Build Coastguard Worker                             const Model& model,
795*89c4ff92SAndroid Build Coastguard Worker                             const ConversionData& data)
796*89c4ff92SAndroid Build Coastguard Worker {
797*89c4ff92SAndroid Build Coastguard Worker     return GetInputScalar(operation, inputIndex, OperandType::FLOAT32, outValue, model, data);
798*89c4ff92SAndroid Build Coastguard Worker }
799*89c4ff92SAndroid Build Coastguard Worker 
GetInputActivationFunctionImpl(const Operation & operation,uint32_t inputIndex,OperandType type,ActivationFn & outActivationFunction,const Model & model,const ConversionData & data)800*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputActivationFunctionImpl(const Operation& operation,
801*89c4ff92SAndroid Build Coastguard Worker                                            uint32_t inputIndex,
802*89c4ff92SAndroid Build Coastguard Worker                                            OperandType type,
803*89c4ff92SAndroid Build Coastguard Worker                                            ActivationFn& outActivationFunction,
804*89c4ff92SAndroid Build Coastguard Worker                                            const Model& model,
805*89c4ff92SAndroid Build Coastguard Worker                                            const ConversionData& data)
806*89c4ff92SAndroid Build Coastguard Worker {
807*89c4ff92SAndroid Build Coastguard Worker     if (type != OperandType::INT32 && type != OperandType::TENSOR_INT32)
808*89c4ff92SAndroid Build Coastguard Worker     {
809*89c4ff92SAndroid Build Coastguard Worker         VLOG(DRIVER) << __func__ << ": unexpected operand type: " << type
810*89c4ff92SAndroid Build Coastguard Worker                      << " should be OperandType::INT32 or OperandType::TENSOR_INT32";
811*89c4ff92SAndroid Build Coastguard Worker         return false;
812*89c4ff92SAndroid Build Coastguard Worker     }
813*89c4ff92SAndroid Build Coastguard Worker 
814*89c4ff92SAndroid Build Coastguard Worker     int32_t activationFunctionAsInt;
815*89c4ff92SAndroid Build Coastguard Worker     if (!GetInputScalar(operation, inputIndex, type, activationFunctionAsInt, model, data))
816*89c4ff92SAndroid Build Coastguard Worker     {
817*89c4ff92SAndroid Build Coastguard Worker         return Fail("%s: failed to get activation input value", __func__);
818*89c4ff92SAndroid Build Coastguard Worker     }
819*89c4ff92SAndroid Build Coastguard Worker     outActivationFunction = static_cast<ActivationFn>(activationFunctionAsInt);
820*89c4ff92SAndroid Build Coastguard Worker     return true;
821*89c4ff92SAndroid Build Coastguard Worker }
822*89c4ff92SAndroid Build Coastguard Worker 
GetInputActivationFunction(const Operation & operation,uint32_t inputIndex,ActivationFn & outActivationFunction,const Model & model,const ConversionData & data)823*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputActivationFunction(const Operation& operation,
824*89c4ff92SAndroid Build Coastguard Worker                                        uint32_t inputIndex,
825*89c4ff92SAndroid Build Coastguard Worker                                        ActivationFn& outActivationFunction,
826*89c4ff92SAndroid Build Coastguard Worker                                        const Model& model,
827*89c4ff92SAndroid Build Coastguard Worker                                        const ConversionData& data)
828*89c4ff92SAndroid Build Coastguard Worker {
829*89c4ff92SAndroid Build Coastguard Worker     return GetInputActivationFunctionImpl(operation,
830*89c4ff92SAndroid Build Coastguard Worker                                           inputIndex,
831*89c4ff92SAndroid Build Coastguard Worker                                           OperandType::INT32,
832*89c4ff92SAndroid Build Coastguard Worker                                           outActivationFunction,
833*89c4ff92SAndroid Build Coastguard Worker                                           model,
834*89c4ff92SAndroid Build Coastguard Worker                                           data);
835*89c4ff92SAndroid Build Coastguard Worker }
836*89c4ff92SAndroid Build Coastguard Worker 
GetInputActivationFunctionFromTensor(const Operation & operation,uint32_t inputIndex,ActivationFn & outActivationFunction,const Model & model,const ConversionData & data)837*89c4ff92SAndroid Build Coastguard Worker inline bool GetInputActivationFunctionFromTensor(const Operation& operation,
838*89c4ff92SAndroid Build Coastguard Worker                                                  uint32_t inputIndex,
839*89c4ff92SAndroid Build Coastguard Worker                                                  ActivationFn& outActivationFunction,
840*89c4ff92SAndroid Build Coastguard Worker                                                  const Model& model,
841*89c4ff92SAndroid Build Coastguard Worker                                                  const ConversionData& data)
842*89c4ff92SAndroid Build Coastguard Worker {
843*89c4ff92SAndroid Build Coastguard Worker     // This only accepts a 1-D tensor of size 1
844*89c4ff92SAndroid Build Coastguard Worker     return GetInputActivationFunctionImpl(operation,
845*89c4ff92SAndroid Build Coastguard Worker                                           inputIndex,
846*89c4ff92SAndroid Build Coastguard Worker                                           OperandType::INT32,
847*89c4ff92SAndroid Build Coastguard Worker                                           outActivationFunction,
848*89c4ff92SAndroid Build Coastguard Worker                                           model,
849*89c4ff92SAndroid Build Coastguard Worker                                           data);
850*89c4ff92SAndroid Build Coastguard Worker }
851*89c4ff92SAndroid Build Coastguard Worker 
852*89c4ff92SAndroid Build Coastguard Worker 
GetOptionalInputActivation(const Operation & operation,uint32_t inputIndex,ActivationFn & activationFunction,const Model & model,const ConversionData & data)853*89c4ff92SAndroid Build Coastguard Worker inline bool GetOptionalInputActivation(const Operation& operation,
854*89c4ff92SAndroid Build Coastguard Worker                                        uint32_t inputIndex,
855*89c4ff92SAndroid Build Coastguard Worker                                        ActivationFn& activationFunction,
856*89c4ff92SAndroid Build Coastguard Worker                                        const Model& model,
857*89c4ff92SAndroid Build Coastguard Worker                                        const ConversionData& data)
858*89c4ff92SAndroid Build Coastguard Worker {
859*89c4ff92SAndroid Build Coastguard Worker     if (operation.inputs.size() <= inputIndex)
860*89c4ff92SAndroid Build Coastguard Worker     {
861*89c4ff92SAndroid Build Coastguard Worker         activationFunction = ActivationFn::kActivationNone;
862*89c4ff92SAndroid Build Coastguard Worker     }
863*89c4ff92SAndroid Build Coastguard Worker     else
864*89c4ff92SAndroid Build Coastguard Worker     {
865*89c4ff92SAndroid Build Coastguard Worker         if (!GetInputActivationFunction(operation, inputIndex, activationFunction, model, data))
866*89c4ff92SAndroid Build Coastguard Worker         {
867*89c4ff92SAndroid Build Coastguard Worker             return Fail("%s: Operation has invalid inputs", __func__);
868*89c4ff92SAndroid Build Coastguard Worker         }
869*89c4ff92SAndroid Build Coastguard Worker     }
870*89c4ff92SAndroid Build Coastguard Worker     return true;
871*89c4ff92SAndroid Build Coastguard Worker }
872*89c4ff92SAndroid Build Coastguard Worker 
873*89c4ff92SAndroid Build Coastguard Worker template<typename ConvolutionDescriptor>
GetOptionalConvolutionDilationParams(const Operation & operation,uint32_t dilationXIndex,ConvolutionDescriptor & descriptor,const Model & model,const ConversionData & data)874*89c4ff92SAndroid Build Coastguard Worker bool GetOptionalConvolutionDilationParams(const Operation& operation,
875*89c4ff92SAndroid Build Coastguard Worker                                           uint32_t dilationXIndex,
876*89c4ff92SAndroid Build Coastguard Worker                                           ConvolutionDescriptor& descriptor,
877*89c4ff92SAndroid Build Coastguard Worker                                           const Model& model,
878*89c4ff92SAndroid Build Coastguard Worker                                           const ConversionData& data)
879*89c4ff92SAndroid Build Coastguard Worker {
880*89c4ff92SAndroid Build Coastguard Worker     bool success = true;
881*89c4ff92SAndroid Build Coastguard Worker     if (operation.inputs.size() >= dilationXIndex + 2)
882*89c4ff92SAndroid Build Coastguard Worker     {
883*89c4ff92SAndroid Build Coastguard Worker         success &= GetInputScalar(operation,
884*89c4ff92SAndroid Build Coastguard Worker                                   dilationXIndex,
885*89c4ff92SAndroid Build Coastguard Worker                                   OperandType::INT32,
886*89c4ff92SAndroid Build Coastguard Worker                                   descriptor.m_DilationX,
887*89c4ff92SAndroid Build Coastguard Worker                                   model,
888*89c4ff92SAndroid Build Coastguard Worker                                   data);
889*89c4ff92SAndroid Build Coastguard Worker         success &= GetInputScalar(operation,
890*89c4ff92SAndroid Build Coastguard Worker                                   dilationXIndex + 1,
891*89c4ff92SAndroid Build Coastguard Worker                                   OperandType::INT32,
892*89c4ff92SAndroid Build Coastguard Worker                                   descriptor.m_DilationY,
893*89c4ff92SAndroid Build Coastguard Worker                                   model,
894*89c4ff92SAndroid Build Coastguard Worker                                   data);
895*89c4ff92SAndroid Build Coastguard Worker     }
896*89c4ff92SAndroid Build Coastguard Worker 
897*89c4ff92SAndroid Build Coastguard Worker     return success;
898*89c4ff92SAndroid Build Coastguard Worker }
899*89c4ff92SAndroid Build Coastguard Worker 
GetOptionalBool(const Operation & operation,uint32_t inputIndex,const Model & model,const ConversionData & data)900*89c4ff92SAndroid Build Coastguard Worker inline bool GetOptionalBool(const Operation& operation,
901*89c4ff92SAndroid Build Coastguard Worker                             uint32_t inputIndex,
902*89c4ff92SAndroid Build Coastguard Worker                             const Model& model,
903*89c4ff92SAndroid Build Coastguard Worker                             const ConversionData& data)
904*89c4ff92SAndroid Build Coastguard Worker {
905*89c4ff92SAndroid Build Coastguard Worker     const Operand* operand = GetInputOperand(operation, inputIndex, model);
906*89c4ff92SAndroid Build Coastguard Worker     if (!operand)
907*89c4ff92SAndroid Build Coastguard Worker     {
908*89c4ff92SAndroid Build Coastguard Worker         return false;
909*89c4ff92SAndroid Build Coastguard Worker     }
910*89c4ff92SAndroid Build Coastguard Worker 
911*89c4ff92SAndroid Build Coastguard Worker     if (!IsBool(*operand))
912*89c4ff92SAndroid Build Coastguard Worker     {
913*89c4ff92SAndroid Build Coastguard Worker         return false;
914*89c4ff92SAndroid Build Coastguard Worker     }
915*89c4ff92SAndroid Build Coastguard Worker 
916*89c4ff92SAndroid Build Coastguard Worker     const void* valueAddress = GetOperandValueReadOnlyAddress(*operand, model, data);
917*89c4ff92SAndroid Build Coastguard Worker     if (!valueAddress)
918*89c4ff92SAndroid Build Coastguard Worker     {
919*89c4ff92SAndroid Build Coastguard Worker         return false;
920*89c4ff92SAndroid Build Coastguard Worker     }
921*89c4ff92SAndroid Build Coastguard Worker 
922*89c4ff92SAndroid Build Coastguard Worker     return *(static_cast<const bool*>(valueAddress));
923*89c4ff92SAndroid Build Coastguard Worker }
924*89c4ff92SAndroid Build Coastguard Worker 
/// Copies the int32 contents of @p operand into @p outValues.
/// NOTE(review): defined outside this header; failure conditions (e.g. non-constant or
/// non-INT32 operands) are presumably reported via the bool return — confirm in the .cpp.
bool GetTensorInt32Values(const Operand& operand,
                                 std::vector<int32_t>& outValues,
                                 const Model& model,
                                 const ConversionData& data);

/// Reads input @p inputIndex of @p operation into @p outPaddingScheme.
/// NOTE(review): defined outside this header; presumably returns false on an unreadable input.
bool GetInputPaddingScheme(const Operation& operation,
                           uint32_t inputIndex,
                           PaddingScheme& outPaddingScheme,
                           const Model& model,
                           const ConversionData& data);

/// Builds a LayerInputHandle for input @p inputIndex of @p operation.
/// @p dimensionMappings defaults to g_DontPermute and @p inputHandle defaults to nullptr.
/// NOTE(review): the effect of both optional arguments is defined in the implementation file.
LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
                                           uint32_t inputIndex,
                                           const Model& model,
                                           ConversionData& data,
                                           const armnn::PermutationVector& dimensionMappings = g_DontPermute,
                                           const LayerInputHandle* inputHandle = nullptr);

/// Sets up output slot @p layerOutputIndex of @p layer for output @p operationOutputIndex of
/// @p operation. Optional arguments allow an output TensorInfo override, a validation callback,
/// a fused activation and output-shape inference.
/// NOTE(review): exact tracking/validation behaviour lives in the implementation file — confirm.
bool SetupAndTrackLayerOutputSlot(const Operation& operation,
                                  uint32_t operationOutputIndex,
                                  armnn::IConnectableLayer& layer,
                                  uint32_t layerOutputIndex,
                                  const Model& model,
                                  ConversionData& data,
                                  const armnn::TensorInfo* overrideOutputInfo = nullptr,
                                  const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr,
                                  const ActivationFn& activationFunction = ActivationFn::kActivationNone,
                                  bool inferOutputShapes = false);

/// Returns the armnn::DataLayout selected by the optional input @p inputIndex of @p operation.
/// NOTE(review): the layout used when the input is absent is decided in the .cpp — confirm.
armnn::DataLayout OptionalDataLayout(const Operation& operation,
                                     uint32_t inputIndex,
                                     const Model& model,
                                     ConversionData& data);
958*89c4ff92SAndroid Build Coastguard Worker 
SetupAndTrackLayerOutputSlot(const Operation & operation,uint32_t outputIndex,armnn::IConnectableLayer & layer,const Model & model,ConversionData & data,const armnn::TensorInfo * overrideOutputInfo=nullptr,const std::function<void (const armnn::TensorInfo &,bool &)> & validateFunc=nullptr,const ActivationFn & activationFunction=ActivationFn::kActivationNone)959*89c4ff92SAndroid Build Coastguard Worker inline bool SetupAndTrackLayerOutputSlot(
960*89c4ff92SAndroid Build Coastguard Worker         const Operation& operation,
961*89c4ff92SAndroid Build Coastguard Worker         uint32_t outputIndex,
962*89c4ff92SAndroid Build Coastguard Worker         armnn::IConnectableLayer& layer,
963*89c4ff92SAndroid Build Coastguard Worker         const Model& model,
964*89c4ff92SAndroid Build Coastguard Worker         ConversionData& data,
965*89c4ff92SAndroid Build Coastguard Worker         const armnn::TensorInfo* overrideOutputInfo = nullptr,
966*89c4ff92SAndroid Build Coastguard Worker         const std::function <void (const armnn::TensorInfo&, bool&)>& validateFunc = nullptr,
967*89c4ff92SAndroid Build Coastguard Worker         const ActivationFn& activationFunction = ActivationFn::kActivationNone)
968*89c4ff92SAndroid Build Coastguard Worker {
969*89c4ff92SAndroid Build Coastguard Worker     return SetupAndTrackLayerOutputSlot(operation,
970*89c4ff92SAndroid Build Coastguard Worker                                         outputIndex,
971*89c4ff92SAndroid Build Coastguard Worker                                         layer,
972*89c4ff92SAndroid Build Coastguard Worker                                         outputIndex,
973*89c4ff92SAndroid Build Coastguard Worker                                         model,
974*89c4ff92SAndroid Build Coastguard Worker                                         data,
975*89c4ff92SAndroid Build Coastguard Worker                                         overrideOutputInfo,
976*89c4ff92SAndroid Build Coastguard Worker                                         validateFunc,
977*89c4ff92SAndroid Build Coastguard Worker                                         activationFunction);
978*89c4ff92SAndroid Build Coastguard Worker }
979*89c4ff92SAndroid Build Coastguard Worker 
/// Converts @p operation into an activation configured by @p activationDesc.
/// @p operationName is presumably used in log messages — confirm in the .cpp.
bool ConvertToActivation(const Operation& operation,
                         const char* operationName,
                         const armnn::ActivationDescriptor& activationDesc,
                         const Model& model,
                         ConversionData& data);

/// Fills @p padDescriptor from the paddings input of @p operation for a tensor of rank @p rank.
/// NOTE(review): which input index is read is decided in the implementation file.
bool ConvertPaddings(const Operation& operation,
                     const Model& model,
                     ConversionData& data,
                     unsigned int rank,
                     armnn::PadDescriptor& padDescriptor);

/// Converts @p operation into a reduction using @p reduceOperation.
bool ConvertReduce(const Operation& operation,
                   const Model& model,
                   ConversionData& data,
                   armnn::ReduceOperation reduceOperation);

/// Converts @p operation into a 2D pooling using @p poolType.
/// @p operationName is presumably used in log messages — confirm in the .cpp.
bool ConvertPooling2d(const Operation& operation,
                      const char* operationName,
                      armnn::PoolingAlgorithm poolType,
                      const Model& model,
                      ConversionData& data);
1001*89c4ff92SAndroid Build Coastguard Worker 
IsQSymm8(const Operand & operand)1002*89c4ff92SAndroid Build Coastguard Worker inline bool IsQSymm8(const Operand& operand)
1003*89c4ff92SAndroid Build Coastguard Worker {
1004*89c4ff92SAndroid Build Coastguard Worker     return operand.type == OperandType::TENSOR_QUANT8_SYMM;
1005*89c4ff92SAndroid Build Coastguard Worker }
1006*89c4ff92SAndroid Build Coastguard Worker 
/// Status carried as the last element of a DequantizeResult.
/// NOTE(review): the per-value meanings below are inferred from the enumerator names —
/// confirm against the DequantizeIfRequired implementation.
enum class DequantizeStatus
{
    SUCCESS         = 0, ///< presumably: the operand was dequantized into the returned buffer
    NOT_REQUIRED    = 1, ///< presumably: the operand needed no dequantization
    INVALID_OPERAND = 2  ///< presumably: the operand could not be read or dequantized
};
1013*89c4ff92SAndroid Build Coastguard Worker 
/// Result of DequantizeIfRequired, with four elements:
/// - an owning float buffer;
/// - a size_t count (NOTE(review): element vs. byte count not visible here — confirm);
/// - the armnn::TensorInfo describing the data;
/// - a DequantizeStatus.
using DequantizeResult = std::tuple<std::unique_ptr<float[]>, size_t, armnn::TensorInfo, DequantizeStatus>;

/// Dequantizes input @p operand_index of @p operation when needed.
/// The outcome is reported in the DequantizeStatus element of the result.
DequantizeResult DequantizeIfRequired(size_t operand_index,
                                      const Operation& operation,
                                      const Model& model,
                                      const ConversionData& data);

/// Builds a ConstTensorPin for input @p operandIndex of @p operation, dequantizing it first if needed.
/// NOTE(review): the handling of @p optional is defined in the implementation file.
ConstTensorPin DequantizeAndMakeConstTensorPin(const Operation& operation,
                                               const Model& model,
                                               const ConversionData& data,
                                               size_t operandIndex,
                                               bool optional = false);

/// @return whether @p ioutputSlot feeds a Dequantize layer (per the function name; confirm in the .cpp).
bool IsConnectedToDequantize(armnn::IOutputSlot* ioutputSlot);
1028*89c4ff92SAndroid Build Coastguard Worker 
1029*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
1030