//
// Copyright © 2017, 2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#define LOG_TAG "ArmnnDriver"

#include "ArmnnDriverImpl.hpp"
#include "ArmnnPreparedModel.hpp"

#if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3) // Using ::android::hardware::neuralnetworks::V1_2
#include "ArmnnPreparedModel_1_2.hpp"
#endif

#ifdef ARMNN_ANDROID_NN_V1_3 // Using ::android::hardware::neuralnetworks::V1_3
#include "ArmnnPreparedModel_1_3.hpp"
#endif

#include "Utils.hpp"

#include "ModelToINetworkConverter.hpp"
#include "SystemPropertiesUtils.hpp"

#include <ValidateHal.h>
#include <log/log.h>
#include <chrono>

using namespace std;
using namespace android;
using namespace android::nn;
using namespace android::hardware;

namespace
{

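// Notifies the NNAPI client's callback with the result of model preparation.
// The HIDL return value is checked so that a misbehaving client cannot bring
// the driver service down through an unchecked transaction failure.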
void NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback>& callback,
                            V1_0::ErrorStatus errorStatus,
                            const sp<V1_0::IPreparedModel>& preparedModelPtr)
{
    Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
    // This check is required: if the callback fails and the error is not checked it will bring down the service
    if (!returned.isOk())
    {
        ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
              returned.description().c_str());
    }
}

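// Logs the failure, reports it to the client's callback and returns the same
// error code so callers can simply write 'return FailPrepareModel(...)'.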
Return<V1_0::ErrorStatus> FailPrepareModel(V1_0::ErrorStatus error,
                                           const string& message,
                                           const sp<V1_0::IPreparedModelCallback>& callback)
{
    ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
    NotifyCallbackAndCheck(callback, error, nullptr);
    return error;
}

} // namespace

namespace armnn_driver
{

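// Prepares a model for execution on the requested backends: convert the HAL
// model to an Arm NN INetwork, optimize it, load it into the runtime,
// optionally warm it up on GpuAcc, and report the result through the
// client's callback.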
template<typename HalPolicy>
Return<V1_0::ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
        const armnn::IRuntimePtr& runtime,
        const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
        const DriverOptions& options,
        const HalModel& model,
        const sp<V1_0::IPreparedModelCallback>& cb,
        bool float32ToFloat16)
{
    ALOGV("ArmnnDriverImpl::prepareModel()");

    std::chrono::time_point<std::chrono::system_clock> prepareModelTimepoint = std::chrono::system_clock::now();

    if (cb.get() == nullptr)
    {
        ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
        return V1_0::ErrorStatus::INVALID_ARGUMENT;
    }

    if (!runtime)
    {
        return FailPrepareModel(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
    }

    if (!android::nn::validateModel(model))
    {
        return FailPrepareModel(V1_0::ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
    }

    // Deliberately ignore any unsupported operations requested by the options -
    // at this point we're being asked to prepare a model that we've already declared support for
    // and the operation indices may be different to those in getSupportedOperations anyway.
    set<unsigned int> unsupportedOperations;
    ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
                                                       model,
                                                       unsupportedOperations);

    if (modelConverter.GetConversionResult() != ConversionResult::Success)
    {
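        // The failure is reported through the callback; returning NONE here indicates
        // only that the asynchronous prepare request itself was accepted.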
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Serialize the network graph to a .armnn file if an output directory
    // has been specified in the drivers' arguments.
    std::vector<uint8_t> dataCacheData;
    auto serializedNetworkFileName =
        SerializeNetwork(*modelConverter.GetINetwork(),
                         options.GetRequestInputsAndOutputsDumpDir(),
                         dataCacheData,
                         false);

    // Optimize the network
    armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
    armnn::OptimizerOptionsOpaque OptOptions;
    OptOptions.SetReduceFp32ToFp16(float32ToFloat16);

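    // Forward the relevant driver options to the GpuAcc and CpuAcc backends as
    // backend-specific model options.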
    armnn::BackendOptions gpuAcc("GpuAcc",
    {
        { "FastMathEnabled", options.IsFastMathEnabled() },
        { "SaveCachedNetwork", options.SaveCachedNetwork() },
        { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
        { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() }
    });

    armnn::BackendOptions cpuAcc("CpuAcc",
    {
        { "FastMathEnabled", options.IsFastMathEnabled() },
        { "NumberOfThreads", options.GetNumberOfThreads() }
    });
    OptOptions.AddModelOption(gpuAcc);
    OptOptions.AddModelOption(cpuAcc);

    std::vector<std::string> errMessages;
    try
    {
        optNet = armnn::Optimize(*modelConverter.GetINetwork(),
                                 options.GetBackends(),
                                 runtime->GetDeviceSpec(),
                                 OptOptions,
                                 errMessages);
    }
    catch (std::exception& e)
    {
        stringstream message;
        message << "Exception (" << e.what() << ") caught from optimize.";
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Check that the optimized network is valid.
    if (!optNet)
    {
        stringstream message;
        message << "Invalid optimized network";
        for (const string& msg : errMessages)
        {
            message << "\n" << msg;
        }
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Export the optimized network graph to a dot file if an output dump directory
    // has been specified in the drivers' arguments.
    std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet, options.GetRequestInputsAndOutputsDumpDir());

    // Load it into the runtime.
    armnn::NetworkId netId = 0;
    std::string msg;
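    // Asynchronous execution is taken from the driver options; memory import/export
    // is left disabled (MemorySource::Undefined) for networks loaded through this path.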
    armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
                                                armnn::MemorySource::Undefined,
                                                armnn::MemorySource::Undefined);

    try
    {
        if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
        {
            return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
        }
    }
    catch (std::exception& e)
    {
        stringstream message;
        message << "Exception (" << e.what() << ") caught from LoadNetwork.";
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Now that we have a networkId for the graph, rename the exported files to use it
    // so that the graph file can be associated with the exported input/output tensor files.
    RenameExportedFiles(serializedNetworkFileName,
                        dotGraphFileName,
                        options.GetRequestInputsAndOutputsDumpDir(),
                        netId);

    sp<ArmnnPreparedModel<HalPolicy>> preparedModel(
            new ArmnnPreparedModel<HalPolicy>(
                    netId,
                    runtime.get(),
                    model,
                    options.GetRequestInputsAndOutputsDumpDir(),
                    options.IsGpuProfilingEnabled(),
                    options.isAsyncModelExecutionEnabled(),
                    options.getNoOfArmnnThreads(),
                    options.isImportEnabled(),
                    options.isExportEnabled()));

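    // If the GpuAcc backend was requested, warm the network up with a dummy inference
    // and, where tuning is enabled, persist the updated OpenCL tuning data.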
    if (std::find(options.GetBackends().begin(),
                  options.GetBackends().end(),
                  armnn::Compute::GpuAcc) != options.GetBackends().end())
    {
        // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
        // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
        if (!preparedModel->ExecuteWithDummyInputs())
        {
            return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
        }

        if (clTunedParameters &&
            options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
        {
            // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file
            try
            {
                clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
            }
            catch (std::exception& error)
            {
                ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
                      options.GetClTunedParametersFile().c_str(), error.what());
            }
        }
    }
    NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel);

    ALOGV("ArmnnDriverImpl::prepareModel cache timing = %lld µs", std::chrono::duration_cast<std::chrono::microseconds>
         (std::chrono::system_clock::now() - prepareModelTimepoint).count());

    return V1_0::ErrorStatus::NONE;
}

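// Reports, per operation in the HAL model, whether it can run on the backends
// selected in the driver options.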
template<typename HalPolicy>
Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
                                                                const DriverOptions& options,
                                                                const HalModel& model,
                                                                HalGetSupportedOperations_cb cb)
{
    std::stringstream ss;
    ss << "ArmnnDriverImpl::getSupportedOperations()";
    std::string fileName;
    std::string timestamp;
    if (!options.GetRequestInputsAndOutputsDumpDir().empty())
    {
        timestamp = GetFileTimestamp();
        fileName = options.GetRequestInputsAndOutputsDumpDir() + "/" + timestamp + "_getSupportedOperations.txt";
        ss << " : " << fileName;
    }
    ALOGV("%s", ss.str().c_str());

    if (!options.GetRequestInputsAndOutputsDumpDir().empty())
    {
        // Dump the marker file
        std::ofstream fileStream;
        fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
        if (fileStream.good())
        {
            fileStream << timestamp << std::endl;
        }
        fileStream.close();
    }

    vector<bool> result;

    if (!runtime)
    {
        cb(HalErrorStatus::DEVICE_UNAVAILABLE, result);
        return Void();
    }

    // Run general model validation; if this doesn't pass we shouldn't analyse the model anyway.
    if (!android::nn::validateModel(model))
    {
        cb(HalErrorStatus::INVALID_ARGUMENT, result);
        return Void();
    }

    // Attempt to convert the model to an ArmNN input network (INetwork).
    ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
                                                       model,
                                                       options.GetForcedUnsupportedOperations());

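    // A result of UnsupportedFeature is acceptable here: the unsupported operations are
    // simply reported as 'false' in the result vector below.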
    if (modelConverter.GetConversionResult() != ConversionResult::Success
            && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
    {
        cb(HalErrorStatus::GENERAL_FAILURE, result);
        return Void();
    }

    // Check whether each operation was converted successfully and copy the flags
    // into the result (vector<bool>) that we need to return to Android.
    result.reserve(getMainModel(model).operations.size());
    for (uint32_t operationIdx = 0;
         operationIdx < getMainModel(model).operations.size();
         ++operationIdx)
    {
        bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
        result.push_back(operationSupported);
    }

    cb(HalErrorStatus::NONE, result);
    return Void();
}

template<typename HalPolicy>
Return<V1_0::DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
{
    ALOGV("ArmnnDriver::getStatus()");

    return V1_0::DeviceStatus::AVAILABLE;
}

///
/// Class template specializations
///

template class ArmnnDriverImpl<hal_1_0::HalPolicy>;

#ifdef ARMNN_ANDROID_NN_V1_1
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
#endif

#ifdef ARMNN_ANDROID_NN_V1_2
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
#endif

#ifdef ARMNN_ANDROID_NN_V1_3
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
template class ArmnnDriverImpl<hal_1_3::HalPolicy>;
#endif

} // namespace armnn_driver