//
// Copyright © 2017, 2023 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#define LOG_TAG "ArmnnDriver"

#include "ArmnnDriverImpl.hpp"
#include "ArmnnPreparedModel.hpp"

#if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3) // Using ::android::hardware::neuralnetworks::V1_2
#include "ArmnnPreparedModel_1_2.hpp"
#endif

#ifdef ARMNN_ANDROID_NN_V1_3 // Using ::android::hardware::neuralnetworks::V1_3
#include "ArmnnPreparedModel_1_3.hpp"
#endif

#include "Utils.hpp"

#include "ModelToINetworkConverter.hpp"
#include "SystemPropertiesUtils.hpp"

#include <ValidateHal.h>
#include <log/log.h>

#include <chrono>
#include <fstream>
#include <set>
#include <sstream>
#include <vector>

using namespace std;
using namespace android;
using namespace android::nn;
using namespace android::hardware;

namespace
{
void NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback>& callback,
                            V1_0::ErrorStatus errorStatus,
                            const sp<V1_0::IPreparedModel>& preparedModelPtr)
{
    Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
    // This check is required; if the callback fails and the failure isn't checked, it will bring down the service.
    if (!returned.isOk())
    {
        ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s",
              returned.description().c_str());
    }
}

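// Logs the failure, notifies the callback with a null prepared model, and passes the
// error back so callers can simply 'return FailPrepareModel(...)'.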
Return<V1_0::ErrorStatus> FailPrepareModel(V1_0::ErrorStatus error,
                                           const string& message,
                                           const sp<V1_0::IPreparedModelCallback>& callback)
{
    ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
    NotifyCallbackAndCheck(callback, error, nullptr);
    return error;
}

} // namespace

namespace armnn_driver
{

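// Converts the HAL model into an Arm NN INetwork, optimizes it for the requested
// backends, loads it into the runtime and returns an ArmnnPreparedModel through the
// supplied callback. A minimal sketch of how a HAL service might drive this entry
// point (the surrounding objects are illustrative, not defined in this file):
//
//     // sp<V1_0::IPreparedModelCallback> cb supplied by the NN runtime
//     ArmnnDriverImpl<hal_1_0::HalPolicy>::prepareModel(
//             runtime, clTunedParameters, options, model, cb, /*float32ToFloat16=*/false);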
template<typename HalPolicy>
Return<V1_0::ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
        const armnn::IRuntimePtr& runtime,
        const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
        const DriverOptions& options,
        const HalModel& model,
        const sp<V1_0::IPreparedModelCallback>& cb,
        bool float32ToFloat16)
{
    ALOGV("ArmnnDriverImpl::prepareModel()");

    std::chrono::time_point<std::chrono::system_clock> prepareModelTimepoint = std::chrono::system_clock::now();

    if (cb.get() == nullptr)
    {
        ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
        return V1_0::ErrorStatus::INVALID_ARGUMENT;
    }

    if (!runtime)
    {
        return FailPrepareModel(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
    }

    if (!android::nn::validateModel(model))
    {
        return FailPrepareModel(V1_0::ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
    }

    // Deliberately ignore any unsupported operations requested by the options -
    // at this point we're being asked to prepare a model that we've already declared support for,
    // and the operation indices may be different to those in getSupportedOperations anyway.
    set<unsigned int> unsupportedOperations;
    ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
                                                       model,
                                                       unsupportedOperations);

    if (modelConverter.GetConversionResult() != ConversionResult::Success)
    {
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
        return V1_0::ErrorStatus::NONE;
    }
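    // The conversion failure has already been reported through the callback above,
    // so the call itself completes with a NONE status.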

    // Serialize the network graph to a .armnn file if an output directory
    // has been specified in the driver's arguments.
    std::vector<uint8_t> dataCacheData;
    auto serializedNetworkFileName =
        SerializeNetwork(*modelConverter.GetINetwork(),
                         options.GetRequestInputsAndOutputsDumpDir(),
                         dataCacheData,
                         false);
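    // The final 'false' argument switches off data caching on this path, so
    // dataCacheData is effectively just plumbing here.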

    // Optimize the network
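    // ReduceFp32ToFp16 lets the optimizer run supported layers in fp16 for speed,
    // trading a small amount of precision; it is driven by the float32ToFloat16 flag.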
    armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
    armnn::OptimizerOptionsOpaque OptOptions;
    OptOptions.SetReduceFp32ToFp16(float32ToFloat16);

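    // Backend-specific tuning options; each backend only reads the options block
    // addressed to it, so it is safe to add both even if only one backend is in use.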
    armnn::BackendOptions gpuAcc("GpuAcc",
    {
        { "FastMathEnabled", options.IsFastMathEnabled() },
        { "SaveCachedNetwork", options.SaveCachedNetwork() },
        { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
        { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() }
    });

    armnn::BackendOptions cpuAcc("CpuAcc",
    {
        { "FastMathEnabled", options.IsFastMathEnabled() },
        { "NumberOfThreads", options.GetNumberOfThreads() }
    });
    OptOptions.AddModelOption(gpuAcc);
    OptOptions.AddModelOption(cpuAcc);

    std::vector<std::string> errMessages;
    try
    {
        optNet = armnn::Optimize(*modelConverter.GetINetwork(),
                                 options.GetBackends(),
                                 runtime->GetDeviceSpec(),
                                 OptOptions,
                                 errMessages);
    }
    catch (std::exception& e)
    {
        stringstream message;
        message << "Exception (" << e.what() << ") caught from optimize.";
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Check that the optimized network is valid.
    if (!optNet)
    {
        stringstream message;
        message << "Invalid optimized network";
        for (const string& msg : errMessages)
        {
            message << "\n" << msg;
        }
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Export the optimized network graph to a dot file if an output dump directory
    // has been specified in the driver's arguments.
    std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet, options.GetRequestInputsAndOutputsDumpDir());

    // Load it into the runtime.
    armnn::NetworkId netId = 0;
    std::string msg;
    armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
                                                armnn::MemorySource::Undefined,
                                                armnn::MemorySource::Undefined);
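    // MemorySource::Undefined disables memory import/export at load time; whether a
    // request can avoid copies is governed by the isImportEnabled/isExportEnabled
    // options passed to the prepared model below.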

    try
    {
        if (runtime->LoadNetwork(netId, std::move(optNet), msg, networkProperties) != armnn::Status::Success)
        {
            return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
        }
    }
    catch (std::exception& e)
    {
        stringstream message;
        message << "Exception (" << e.what() << ") caught from LoadNetwork.";
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Now that we have a networkId for the graph, rename the exported files to use it
    // so that we can associate the graph file with the input/output tensor exported files.
    RenameExportedFiles(serializedNetworkFileName,
                        dotGraphFileName,
                        options.GetRequestInputsAndOutputsDumpDir(),
                        netId);

    sp<ArmnnPreparedModel<HalPolicy>> preparedModel(
            new ArmnnPreparedModel<HalPolicy>(
                    netId,
                    runtime.get(),
                    model,
                    options.GetRequestInputsAndOutputsDumpDir(),
                    options.IsGpuProfilingEnabled(),
                    options.isAsyncModelExecutionEnabled(),
                    options.getNoOfArmnnThreads(),
                    options.isImportEnabled(),
                    options.isExportEnabled()));

    if (std::find(options.GetBackends().begin(),
                  options.GetBackends().end(),
                  armnn::Compute::GpuAcc) != options.GetBackends().end())
    {
        // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled
        // (and tuned, if tuning is enabled) before the first 'real' inference, removing the overhead
        // from that first run.
        if (!preparedModel->ExecuteWithDummyInputs())
        {
            return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
        }

        if (clTunedParameters &&
            options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
        {
            // Now that we've done one inference the CL kernel parameters will have been tuned,
            // so save the updated file.
            try
            {
                clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
            }
            catch (std::exception& error)
            {
                ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
                      options.GetClTunedParametersFile().c_str(), error.what());
            }
        }
    }
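    // Success: hand the prepared model back to the framework through the callback;
    // the synchronous return only reflects errors detected before this point.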
    NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel);

    ALOGV("ArmnnDriverImpl::prepareModel cache timing = %lld µs", std::chrono::duration_cast<std::chrono::microseconds>
         (std::chrono::system_clock::now() - prepareModelTimepoint).count());

    return V1_0::ErrorStatus::NONE;
}

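// Reports, per operation, whether the requested backends can run the model. The
// whole model is converted (tolerating unsupported features) and each operation's
// conversion result is copied into the vector<bool> handed back to Android.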
template<typename HalPolicy>
Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
                                                                const DriverOptions& options,
                                                                const HalModel& model,
                                                                HalGetSupportedOperations_cb cb)
{
    std::stringstream ss;
    ss << "ArmnnDriverImpl::getSupportedOperations()";
    std::string fileName;
    std::string timestamp;
    if (!options.GetRequestInputsAndOutputsDumpDir().empty())
    {
        timestamp = GetFileTimestamp();
        fileName = options.GetRequestInputsAndOutputsDumpDir() + "/" + timestamp + "_getSupportedOperations.txt";
        ss << " : " << fileName;
    }
    ALOGV("%s", ss.str().c_str());

    if (!options.GetRequestInputsAndOutputsDumpDir().empty())
    {
        // Dump the marker file
        std::ofstream fileStream;
        fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
        if (fileStream.good())
        {
            fileStream << timestamp << std::endl;
        }
        fileStream.close();
    }

    vector<bool> result;

    if (!runtime)
    {
        cb(HalErrorStatus::DEVICE_UNAVAILABLE, result);
        return Void();
    }

    // Run general model validation; if this doesn't pass we shouldn't analyse the model anyway.
    if (!android::nn::validateModel(model))
    {
        cb(HalErrorStatus::INVALID_ARGUMENT, result);
        return Void();
    }

    // Attempt to convert the model to an ArmNN input network (INetwork).
    ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
                                                       model,
                                                       options.GetForcedUnsupportedOperations());

    if (modelConverter.GetConversionResult() != ConversionResult::Success
            && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
    {
        cb(HalErrorStatus::GENERAL_FAILURE, result);
        return Void();
    }

    // Check whether each operation was converted successfully and copy the flags
    // into the result (vector<bool>) that we need to return to Android.
    result.reserve(getMainModel(model).operations.size());
    for (uint32_t operationIdx = 0;
         operationIdx < getMainModel(model).operations.size();
         ++operationIdx)
    {
        bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
        result.push_back(operationSupported);
    }

    cb(HalErrorStatus::NONE, result);
    return Void();
}

template<typename HalPolicy>
Return<V1_0::DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
{
    ALOGV("ArmnnDriver::getStatus()");

    return V1_0::DeviceStatus::AVAILABLE;
}

///
/// Class template specializations
///

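// Explicit instantiation keeps the template definitions in this translation unit;
// each ARMNN_ANDROID_NN_V1_x build also instantiates the policies for the HAL
// versions below it.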
template class ArmnnDriverImpl<hal_1_0::HalPolicy>;

#ifdef ARMNN_ANDROID_NN_V1_1
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
#endif

#ifdef ARMNN_ANDROID_NN_V1_2
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
#endif

#ifdef ARMNN_ANDROID_NN_V1_3
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
template class ArmnnDriverImpl<hal_1_3::HalPolicy>;
#endif

} // namespace armnn_driver