1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017, 2023 Arm Ltd. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li
6*3e777be0SXin Li #define LOG_TAG "ArmnnDriver"
7*3e777be0SXin Li
8*3e777be0SXin Li #include "ArmnnDriverImpl.hpp"
9*3e777be0SXin Li #include "ArmnnPreparedModel.hpp"
10*3e777be0SXin Li
11*3e777be0SXin Li #if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3) // Using ::android::hardware::neuralnetworks::V1_2
12*3e777be0SXin Li #include "ArmnnPreparedModel_1_2.hpp"
13*3e777be0SXin Li #endif
14*3e777be0SXin Li
15*3e777be0SXin Li #ifdef ARMNN_ANDROID_NN_V1_3 // Using ::android::hardware::neuralnetworks::V1_2
16*3e777be0SXin Li #include "ArmnnPreparedModel_1_3.hpp"
17*3e777be0SXin Li #endif
18*3e777be0SXin Li
19*3e777be0SXin Li #include "Utils.hpp"
20*3e777be0SXin Li
21*3e777be0SXin Li #include "ModelToINetworkConverter.hpp"
22*3e777be0SXin Li #include "SystemPropertiesUtils.hpp"
23*3e777be0SXin Li
24*3e777be0SXin Li #include <ValidateHal.h>
25*3e777be0SXin Li #include <log/log.h>
26*3e777be0SXin Li #include <chrono>
27*3e777be0SXin Li
28*3e777be0SXin Li using namespace std;
29*3e777be0SXin Li using namespace android;
30*3e777be0SXin Li using namespace android::nn;
31*3e777be0SXin Li using namespace android::hardware;
32*3e777be0SXin Li
33*3e777be0SXin Li namespace
34*3e777be0SXin Li {
35*3e777be0SXin Li
NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback> & callback,V1_0::ErrorStatus errorStatus,const sp<V1_0::IPreparedModel> & preparedModelPtr)36*3e777be0SXin Li void NotifyCallbackAndCheck(const sp<V1_0::IPreparedModelCallback>& callback,
37*3e777be0SXin Li V1_0::ErrorStatus errorStatus,
38*3e777be0SXin Li const sp<V1_0::IPreparedModel>& preparedModelPtr)
39*3e777be0SXin Li {
40*3e777be0SXin Li Return<void> returned = callback->notify(errorStatus, preparedModelPtr);
41*3e777be0SXin Li // This check is required, if the callback fails and it isn't checked it will bring down the service
42*3e777be0SXin Li if (!returned.isOk())
43*3e777be0SXin Li {
44*3e777be0SXin Li ALOGE("ArmnnDriverImpl::prepareModel: hidl callback failed to return properly: %s ",
45*3e777be0SXin Li returned.description().c_str());
46*3e777be0SXin Li }
47*3e777be0SXin Li }
48*3e777be0SXin Li
FailPrepareModel(V1_0::ErrorStatus error,const string & message,const sp<V1_0::IPreparedModelCallback> & callback)49*3e777be0SXin Li Return<V1_0::ErrorStatus> FailPrepareModel(V1_0::ErrorStatus error,
50*3e777be0SXin Li const string& message,
51*3e777be0SXin Li const sp<V1_0::IPreparedModelCallback>& callback)
52*3e777be0SXin Li {
53*3e777be0SXin Li ALOGW("ArmnnDriverImpl::prepareModel: %s", message.c_str());
54*3e777be0SXin Li NotifyCallbackAndCheck(callback, error, nullptr);
55*3e777be0SXin Li return error;
56*3e777be0SXin Li }
57*3e777be0SXin Li
58*3e777be0SXin Li } // namespace
59*3e777be0SXin Li
60*3e777be0SXin Li namespace armnn_driver
61*3e777be0SXin Li {
62*3e777be0SXin Li
/// Converts the given HAL model into an ArmNN network, optimizes it, loads it
/// into the runtime, and delivers a prepared model through the HIDL callback.
///
/// @param runtime            ArmNN runtime the optimized network is loaded into.
/// @param clTunedParameters  GPU tuning data; saved back to file after the warm-up
///                           inference when UpdateTunedParameters mode is selected.
/// @param options            Driver options (backends, dump dirs, caching, threading).
/// @param model              The HAL model to prepare.
/// @param cb                 HIDL callback that receives the prepared model or an error.
/// @param float32ToFloat16   When true, asks the optimizer to reduce FP32 ops to FP16.
///
/// NOTE(review): several failure paths call FailPrepareModel (which notifies the
/// callback with the error) but then return ErrorStatus::NONE from the HIDL call
/// itself, while the early validation paths return the error directly - presumably
/// the transport-level return only reflects request acceptance; confirm against the
/// NNAPI HAL contract before changing.
template<typename HalPolicy>
Return<V1_0::ErrorStatus> ArmnnDriverImpl<HalPolicy>::prepareModel(
        const armnn::IRuntimePtr& runtime,
        const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
        const DriverOptions& options,
        const HalModel& model,
        const sp<V1_0::IPreparedModelCallback>& cb,
        bool float32ToFloat16)
{
    ALOGV("ArmnnDriverImpl::prepareModel()");

    // Start of the wall-clock measurement logged at the end of preparation.
    std::chrono::time_point<std::chrono::system_clock> prepareModelTimepoint = std::chrono::system_clock::now();

    // Without a callback there is no way to report a result at all, so bail
    // out with a direct error return.
    if (cb.get() == nullptr)
    {
        ALOGW("ArmnnDriverImpl::prepareModel: Invalid callback passed to prepareModel");
        return V1_0::ErrorStatus::INVALID_ARGUMENT;
    }

    if (!runtime)
    {
        return FailPrepareModel(V1_0::ErrorStatus::DEVICE_UNAVAILABLE, "Device unavailable", cb);
    }

    if (!android::nn::validateModel(model))
    {
        return FailPrepareModel(V1_0::ErrorStatus::INVALID_ARGUMENT, "Invalid model passed as input", cb);
    }

    // Deliberately ignore any unsupported operations requested by the options -
    // at this point we're being asked to prepare a model that we've already declared support for
    // and the operation indices may be different to those in getSupportedOperations anyway.
    set<unsigned int> unsupportedOperations;
    ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
                                                       model,
                                                       unsupportedOperations);

    if (modelConverter.GetConversionResult() != ConversionResult::Success)
    {
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "ModelToINetworkConverter failed", cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Serialize the network graph to a .armnn file if an output directory
    // has been specified in the drivers' arguments.
    std::vector<uint8_t> dataCacheData;
    auto serializedNetworkFileName =
        SerializeNetwork(*modelConverter.GetINetwork(),
                         options.GetRequestInputsAndOutputsDumpDir(),
                         dataCacheData,
                         false);

    // Optimize the network
    armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
    armnn::OptimizerOptionsOpaque OptOptions;
    OptOptions.SetReduceFp32ToFp16(float32ToFloat16);

    // Backend-specific tuning/caching knobs forwarded from the driver options.
    armnn::BackendOptions gpuAcc("GpuAcc",
    {
        { "FastMathEnabled", options.IsFastMathEnabled() },
        { "SaveCachedNetwork", options.SaveCachedNetwork() },
        { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
        { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() }

    });

    armnn::BackendOptions cpuAcc("CpuAcc",
    {
        { "FastMathEnabled", options.IsFastMathEnabled() },
        { "NumberOfThreads", options.GetNumberOfThreads() }
    });
    OptOptions.AddModelOption(gpuAcc);
    OptOptions.AddModelOption(cpuAcc);

    std::vector<std::string> errMessages;
    try
    {
        optNet = armnn::Optimize(*modelConverter.GetINetwork(),
                                 options.GetBackends(),
                                 runtime->GetDeviceSpec(),
                                 OptOptions,
                                 errMessages);
    }
    catch (std::exception &e)
    {
        stringstream message;
        message << "Exception (" << e.what() << ") caught from optimize.";
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Check that the optimized network is valid.
    if (!optNet)
    {
        stringstream message;
        message << "Invalid optimized network";
        // Surface the optimizer's diagnostics alongside the failure.
        for (const string& msg : errMessages)
        {
            message << "\n" << msg;
        }
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Export the optimized network graph to a dot file if an output dump directory
    // has been specified in the drivers' arguments.
    std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet, options.GetRequestInputsAndOutputsDumpDir());

    // Load it into the runtime.
    armnn::NetworkId netId = 0;
    std::string msg;
    armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
                                                armnn::MemorySource::Undefined,
                                                armnn::MemorySource::Undefined);

    try
    {
        if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
        {
            return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be loaded", cb);
        }
    }
    catch (std::exception& e)
    {
        stringstream message;
        message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
        FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, message.str(), cb);
        return V1_0::ErrorStatus::NONE;
    }

    // Now that we have a networkId for the graph rename the exported files to use it
    // so that we can associate the graph file and the input/output tensor exported files
    RenameExportedFiles(serializedNetworkFileName,
                        dotGraphFileName,
                        options.GetRequestInputsAndOutputsDumpDir(),
                        netId);

    // Wrap the loaded network in the HIDL-facing prepared-model object.
    sp<ArmnnPreparedModel<HalPolicy>> preparedModel(
            new ArmnnPreparedModel<HalPolicy>(
                    netId,
                    runtime.get(),
                    model,
                    options.GetRequestInputsAndOutputsDumpDir(),
                    options.IsGpuProfilingEnabled(),
                    options.isAsyncModelExecutionEnabled(),
                    options.getNoOfArmnnThreads(),
                    options.isImportEnabled(),
                    options.isExportEnabled()));

    // GPU-specific warm-up and tuning; only relevant when GpuAcc was requested.
    if (std::find(options.GetBackends().begin(),
                  options.GetBackends().end(),
                  armnn::Compute::GpuAcc) != options.GetBackends().end())
    {
        // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
        // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
        if (!preparedModel->ExecuteWithDummyInputs())
        {
            return FailPrepareModel(V1_0::ErrorStatus::GENERAL_FAILURE, "Network could not be executed", cb);
        }

        if (clTunedParameters &&
            options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
        {
            // Now that we've done one inference the CL kernel parameters will have been tuned, so save the updated file
            try
            {
                clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
            }
            catch (std::exception& error)
            {
                // Best-effort: a failed save is logged but does not fail preparation.
                ALOGE("ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file '%s': %s",
                      options.GetClTunedParametersFile().c_str(), error.what());
            }
        }
    }
    // Success: hand the prepared model back to the caller.
    NotifyCallbackAndCheck(cb, V1_0::ErrorStatus::NONE, preparedModel);

    ALOGV("ArmnnDriverImpl::prepareModel cache timing = %lld µs", std::chrono::duration_cast<std::chrono::microseconds>
          (std::chrono::system_clock::now() - prepareModelTimepoint).count());

    return V1_0::ErrorStatus::NONE;
}
245*3e777be0SXin Li
246*3e777be0SXin Li template<typename HalPolicy>
getSupportedOperations(const armnn::IRuntimePtr & runtime,const DriverOptions & options,const HalModel & model,HalGetSupportedOperations_cb cb)247*3e777be0SXin Li Return<void> ArmnnDriverImpl<HalPolicy>::getSupportedOperations(const armnn::IRuntimePtr& runtime,
248*3e777be0SXin Li const DriverOptions& options,
249*3e777be0SXin Li const HalModel& model,
250*3e777be0SXin Li HalGetSupportedOperations_cb cb)
251*3e777be0SXin Li {
252*3e777be0SXin Li std::stringstream ss;
253*3e777be0SXin Li ss << "ArmnnDriverImpl::getSupportedOperations()";
254*3e777be0SXin Li std::string fileName;
255*3e777be0SXin Li std::string timestamp;
256*3e777be0SXin Li if (!options.GetRequestInputsAndOutputsDumpDir().empty())
257*3e777be0SXin Li {
258*3e777be0SXin Li ss << " : "
259*3e777be0SXin Li << options.GetRequestInputsAndOutputsDumpDir()
260*3e777be0SXin Li << "/"
261*3e777be0SXin Li << GetFileTimestamp()
262*3e777be0SXin Li << "_getSupportedOperations.txt";
263*3e777be0SXin Li }
264*3e777be0SXin Li ALOGV(ss.str().c_str());
265*3e777be0SXin Li
266*3e777be0SXin Li if (!options.GetRequestInputsAndOutputsDumpDir().empty())
267*3e777be0SXin Li {
268*3e777be0SXin Li //dump the marker file
269*3e777be0SXin Li std::ofstream fileStream;
270*3e777be0SXin Li fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
271*3e777be0SXin Li if (fileStream.good())
272*3e777be0SXin Li {
273*3e777be0SXin Li fileStream << timestamp << std::endl;
274*3e777be0SXin Li }
275*3e777be0SXin Li fileStream.close();
276*3e777be0SXin Li }
277*3e777be0SXin Li
278*3e777be0SXin Li vector<bool> result;
279*3e777be0SXin Li
280*3e777be0SXin Li if (!runtime)
281*3e777be0SXin Li {
282*3e777be0SXin Li cb(HalErrorStatus::DEVICE_UNAVAILABLE, result);
283*3e777be0SXin Li return Void();
284*3e777be0SXin Li }
285*3e777be0SXin Li
286*3e777be0SXin Li // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
287*3e777be0SXin Li if (!android::nn::validateModel(model))
288*3e777be0SXin Li {
289*3e777be0SXin Li cb(HalErrorStatus::INVALID_ARGUMENT, result);
290*3e777be0SXin Li return Void();
291*3e777be0SXin Li }
292*3e777be0SXin Li
293*3e777be0SXin Li // Attempt to convert the model to an ArmNN input network (INetwork).
294*3e777be0SXin Li ModelToINetworkConverter<HalPolicy> modelConverter(options.GetBackends(),
295*3e777be0SXin Li model,
296*3e777be0SXin Li options.GetForcedUnsupportedOperations());
297*3e777be0SXin Li
298*3e777be0SXin Li if (modelConverter.GetConversionResult() != ConversionResult::Success
299*3e777be0SXin Li && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
300*3e777be0SXin Li {
301*3e777be0SXin Li cb(HalErrorStatus::GENERAL_FAILURE, result);
302*3e777be0SXin Li return Void();
303*3e777be0SXin Li }
304*3e777be0SXin Li
305*3e777be0SXin Li // Check each operation if it was converted successfully and copy the flags
306*3e777be0SXin Li // into the result (vector<bool>) that we need to return to Android.
307*3e777be0SXin Li result.reserve(getMainModel(model).operations.size());
308*3e777be0SXin Li for (uint32_t operationIdx = 0;
309*3e777be0SXin Li operationIdx < getMainModel(model).operations.size();
310*3e777be0SXin Li ++operationIdx)
311*3e777be0SXin Li {
312*3e777be0SXin Li bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
313*3e777be0SXin Li result.push_back(operationSupported);
314*3e777be0SXin Li }
315*3e777be0SXin Li
316*3e777be0SXin Li cb(HalErrorStatus::NONE, result);
317*3e777be0SXin Li return Void();
318*3e777be0SXin Li }
319*3e777be0SXin Li
320*3e777be0SXin Li template<typename HalPolicy>
getStatus()321*3e777be0SXin Li Return<V1_0::DeviceStatus> ArmnnDriverImpl<HalPolicy>::getStatus()
322*3e777be0SXin Li {
323*3e777be0SXin Li ALOGV("ArmnnDriver::getStatus()");
324*3e777be0SXin Li
325*3e777be0SXin Li return V1_0::DeviceStatus::AVAILABLE;
326*3e777be0SXin Li }
327*3e777be0SXin Li
///
/// Class template specializations
///

// HAL 1.0 support is always compiled in; newer HAL policies are added only
// when the build targets the corresponding Android NN HAL version.
template class ArmnnDriverImpl<hal_1_0::HalPolicy>;

#ifdef ARMNN_ANDROID_NN_V1_1
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
#endif

#ifdef ARMNN_ANDROID_NN_V1_2
// A 1.2 build also instantiates the 1.1 policy so older models keep working.
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
#endif

#ifdef ARMNN_ANDROID_NN_V1_3
// A 1.3 build also instantiates the 1.1 and 1.2 policies.
template class ArmnnDriverImpl<hal_1_1::HalPolicy>;
template class ArmnnDriverImpl<hal_1_2::HalPolicy>;
template class ArmnnDriverImpl<hal_1_3::HalPolicy>;
#endif
348*3e777be0SXin Li
349*3e777be0SXin Li } // namespace armnn_driver
350