//
// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ConversionUtils.hpp"
#include <armnnUtils/Permute.hpp>

///
/// Helper classes
///

namespace armnn_driver
{

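// LayerInputHandle wraps the output slot that feeds an operand into a layer,
// together with that operand's TensorInfo. The default constructor produces an
// invalid handle with no slot; the second constructor captures a slot, its
// validity flag and the associated tensor info.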
LayerInputHandle::LayerInputHandle()
    : m_OutputSlot(nullptr)
    , m_Valid(false)
{}

LayerInputHandle::LayerInputHandle(bool valid, armnn::IOutputSlot* outputSlot, armnn::TensorInfo tensorInfo)
    : m_OutputSlot(outputSlot)
    , m_Valid(valid)
    , m_TensorInfo(tensorInfo)
{}

bool LayerInputHandle::IsValid() const
{
    return m_Valid;
}

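// Connects the wrapped output slot to the given input slot. A handle that was
// never populated with a valid operand is treated as an error.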
void LayerInputHandle::Connect(armnn::IInputSlot& inputSlot)
{
    if (!IsValid())
    {
        throw armnn::RuntimeException("LayerInputHandle is invalid");
    }

    if (m_OutputSlot)
    {
        m_OutputSlot->Connect(inputSlot);
    }
}

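// Breaks an existing connection between the wrapped output slot and the given
// input slot. As with Connect(), an invalid handle is an error.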
void LayerInputHandle::Disconnect(armnn::IInputSlot& inputSlot)
{
    if (!IsValid())
    {
        throw armnn::RuntimeException("LayerInputHandle is invalid");
    }
    if (m_OutputSlot)
    {
        m_OutputSlot->Disconnect(inputSlot);
    }
}

const armnn::TensorInfo& LayerInputHandle::GetTensorInfo() const
{
    return m_TensorInfo;
}

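// Adjusts the quantization scale of a bias operand (this handle) via
// SanitizeBiasQuantizationScale(), using the weight and input tensor infos,
// then writes the corrected TensorInfo back to both the handle and the output
// slot that produces the bias.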
void LayerInputHandle::SanitizeQuantizationScale(LayerInputHandle& weight,
                                                 LayerInputHandle& input)
{
    if (m_OutputSlot)
    {
        armnn::TensorInfo weightInfo = weight.GetTensorInfo();
        armnn::TensorInfo inputInfo = input.GetTensorInfo();
        armnn::TensorInfo biasInfo = GetTensorInfo();

        SanitizeBiasQuantizationScale(biasInfo, weightInfo, inputInfo);

        m_TensorInfo = biasInfo;
        m_OutputSlot->SetTensorInfo(biasInfo);
    }
}

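// ConstTensorPin keeps a constant operand alive for the duration of the
// conversion. The single-argument constructor creates an empty pin that only
// records whether the operand was optional; the main constructor wraps the raw
// operand data in an armnn::ConstTensor, optionally permuting it first.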
ConstTensorPin::ConstTensorPin(bool optional)
    : m_Optional(optional)
{}

ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo,
                               const void* valueStart,
                               uint32_t numBytes,
                               const armnn::PermutationVector& mappings)
    : m_Optional(false)
{
    armnn::IgnoreUnused(numBytes);
    if (tensorInfo.GetNumBytes() != numBytes)
    {
        ALOGW("The size of ConstTensor does not match its TensorInfo.");
    }

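    // A non-empty permutation vector means the operand data has to be re-laid-out
    // for Arm NN, so it is swizzled into locally owned storage and the ConstTensor
    // points at that copy rather than at the caller's buffer.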
    const bool needsSwizzling = (mappings.GetSize() > 0);
    if (needsSwizzling)
    {
        m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
        SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);

        m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data());
    }
    else
    {
        m_ConstTensor = armnn::ConstTensor(tensorInfo, valueStart);
    }
}

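// Accessors: a pin is valid once it wraps actual memory, while m_Optional
// records whether the operand was allowed to be absent in the model.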
bool ConstTensorPin::IsValid() const
{
    return m_ConstTensor.GetMemoryArea() != nullptr;
}

bool ConstTensorPin::IsOptional() const
{
    return m_Optional;
}

const armnn::ConstTensor& ConstTensorPin::GetConstTensor() const
{
    return m_ConstTensor;
}

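// Returns a pointer to the wrapped tensor, or nullptr when the pin is invalid
// or holds an empty tensor (e.g. an optional operand that was not provided).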
const armnn::ConstTensor* ConstTensorPin::GetConstTensorPtr() const
{
    if (IsValid() && m_ConstTensor.GetNumElements() > 0)
    {
        return &m_ConstTensor;
    }
    // tensor is either invalid, or has no elements (indicating an optional tensor that was not provided)
    return nullptr;
}

///
/// Utility functions
///

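// Sets the given TensorInfo on prevLayer's single output and, if the operation
// carries a fused activation, appends the matching armnn activation layer
// behind it. The ActivationFn values map onto armnn descriptors as follows:
// Relu -> ReLu, Relu1 -> BoundedReLu with bounds [-1, 1], Relu6 -> BoundedReLu
// with an upper bound of 6, plus Sigmoid and TanH. Returns the layer whose
// output now carries the operation's result (prevLayer itself when no
// activation is fused), or nullptr on failure.
//
// A minimal usage sketch, assuming a converter has already added some layer
// `convLayer` for the operation and extracted its fused `activation` code
// (these names are illustrative, not from this file):
//
//     armnn::IConnectableLayer* endLayer =
//         ProcessActivation(outputInfo, activation, convLayer, data);
//     if (endLayer == nullptr)
//     {
//         return Fail("%s: ProcessActivation failed", __func__);
//     }
//     // endLayer->GetOutputSlot(0) is then connected to the model's output.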
armnn::IConnectableLayer* ProcessActivation(const armnn::TensorInfo& tensorInfo,
                                            ActivationFn activation,
                                            armnn::IConnectableLayer* prevLayer,
                                            ConversionData& data)
{
    if (prevLayer->GetNumOutputSlots() != 1)
    {
        Fail("%s: Incorrect Number of OutputSlots expected 1 was %i", __func__, prevLayer->GetNumOutputSlots());
        return nullptr;
    }
    prevLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);

    armnn::IConnectableLayer* activationLayer = prevLayer;

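    // Only operations with a fused activation need an extra layer; for
    // kActivationNone the result is simply prevLayer.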
    if (activation != ActivationFn::kActivationNone)
    {
        armnn::ActivationDescriptor activationDesc;
        switch (activation)
        {
            case ActivationFn::kActivationRelu:
            {
                activationDesc.m_Function = armnn::ActivationFunction::ReLu;
                break;
            }
            case ActivationFn::kActivationRelu1:
            {
                activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
                activationDesc.m_A = 1.0f;
                activationDesc.m_B = -1.0f;
                break;
            }
            case ActivationFn::kActivationRelu6:
            {
                activationDesc.m_Function = armnn::ActivationFunction::BoundedReLu;
                activationDesc.m_A = 6.0f;
                break;
            }
            case ActivationFn::kActivationSigmoid:
            {
                activationDesc.m_Function = armnn::ActivationFunction::Sigmoid;
                break;
            }
            case ActivationFn::kActivationTanh:
            {
                activationDesc.m_Function = armnn::ActivationFunction::TanH;
                activationDesc.m_A = 1.0f;
                activationDesc.m_B = 1.0f;
                break;
            }
            default:
            {
                Fail("%s: Invalid activation enum value %i", __func__, activation);
                return nullptr;
            }
        }

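        // Check that one of the requested backends supports this activation before
        // adding it to the network; the chosen backend is reported back so it can
        // be pinned on the new layer.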
        bool isSupported = false;
        armnn::BackendId setBackend;
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   IsActivationSupported,
                                   data.m_Backends,
                                   isSupported,
                                   setBackend,
                                   prevLayer->GetOutputSlot(0).GetTensorInfo(),
                                   tensorInfo,
                                   activationDesc);
        if (!isSupported)
        {
            return nullptr;
        }

        activationLayer = data.m_Network->AddActivationLayer(activationDesc);
        activationLayer->SetBackendId(setBackend);

        prevLayer->GetOutputSlot(0).Connect(activationLayer->GetInputSlot(0));
        activationLayer->GetOutputSlot(0).SetTensorInfo(tensorInfo);
    }

    return activationLayer;
}

} // namespace armnn_driver