1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriverImpl.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnPreparedModel.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "ModelToINetworkTransformer.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "SystemPropertiesUtils.hpp"
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <log/log.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <sys/stat.h>
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker namespace
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker
GenerateCapabilities()19*89c4ff92SAndroid Build Coastguard Worker Capabilities GenerateCapabilities()
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriverImpl::GenerateCapabilities()";
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker float defaultPerfValue = .1f;
24*89c4ff92SAndroid Build Coastguard Worker const Capabilities::PerformanceInfo defaultPerfInfo = { /* execTime */ defaultPerfValue,
25*89c4ff92SAndroid Build Coastguard Worker /* powerUsage */ defaultPerfValue
26*89c4ff92SAndroid Build Coastguard Worker };
27*89c4ff92SAndroid Build Coastguard Worker std::vector<OperandType> operandsTypes({
28*89c4ff92SAndroid Build Coastguard Worker OperandType::FLOAT32,
29*89c4ff92SAndroid Build Coastguard Worker OperandType::INT32,
30*89c4ff92SAndroid Build Coastguard Worker OperandType::UINT32,
31*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_FLOAT32,
32*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_INT32,
33*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_QUANT8_ASYMM,
34*89c4ff92SAndroid Build Coastguard Worker OperandType::BOOL,
35*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_QUANT16_SYMM,
36*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_FLOAT16,
37*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_BOOL8,
38*89c4ff92SAndroid Build Coastguard Worker OperandType::FLOAT16,
39*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
40*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_QUANT16_ASYMM,
41*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_QUANT8_SYMM,
42*89c4ff92SAndroid Build Coastguard Worker OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
43*89c4ff92SAndroid Build Coastguard Worker });
44*89c4ff92SAndroid Build Coastguard Worker
45*89c4ff92SAndroid Build Coastguard Worker std::vector<Capabilities::OperandPerformance> operandPerformances;
46*89c4ff92SAndroid Build Coastguard Worker operandPerformances.reserve(operandsTypes.size());
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker for (auto opType : operandsTypes)
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker operandPerformances.push_back(
51*89c4ff92SAndroid Build Coastguard Worker Capabilities::OperandPerformance{ /* type */ opType, /* info */ defaultPerfInfo });
52*89c4ff92SAndroid Build Coastguard Worker }
53*89c4ff92SAndroid Build Coastguard Worker
54*89c4ff92SAndroid Build Coastguard Worker auto operandPerformanceTable =
55*89c4ff92SAndroid Build Coastguard Worker Capabilities::OperandPerformanceTable::create(std::move(operandPerformances)).value();
56*89c4ff92SAndroid Build Coastguard Worker
57*89c4ff92SAndroid Build Coastguard Worker return { /* relaxedFloat32toFloat16PerformanceScalar */ defaultPerfInfo,
58*89c4ff92SAndroid Build Coastguard Worker /* relaxedFloat32toFloat16PerformanceTensor */ defaultPerfInfo,
59*89c4ff92SAndroid Build Coastguard Worker /* operandPerformance */ std::move(operandPerformanceTable),
60*89c4ff92SAndroid Build Coastguard Worker /* ifPerformance */ defaultPerfInfo,
61*89c4ff92SAndroid Build Coastguard Worker /* whilePerformance */ defaultPerfInfo };
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
/// Computes a simple base-31 rolling hash over a byte buffer.
///
/// The hash is seeded with the buffer length. It is then updated for every byte b, in order,
/// as hash = hash * 31 + b (written as (hash << 5) - hash). The arithmetic is on size_t, so it
/// wraps and the result depends on the platform's size_t width. In this file the value
/// fingerprints serialized-network and model-cache contents, so changing the algorithm
/// invalidates previously written caches.
///
/// @param cacheData Bytes to hash; not modified (taken by const reference so const buffers
///                  and temporaries are accepted).
/// @return The hash value.
size_t Hash(const std::vector<uint8_t>& cacheData)
{
    std::size_t hash = cacheData.size();
    for (const uint8_t byte : cacheData)
    {
        hash = ((hash << 5) - hash) + byte;
    }
    return hash;
}
73*89c4ff92SAndroid Build Coastguard Worker
74*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker
ValidateSharedHandle(const SharedHandle & sharedHandle)81*89c4ff92SAndroid Build Coastguard Worker bool ArmnnDriverImpl::ValidateSharedHandle(const SharedHandle& sharedHandle)
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker bool valid = true;
84*89c4ff92SAndroid Build Coastguard Worker
85*89c4ff92SAndroid Build Coastguard Worker if (*sharedHandle < 0)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker return !valid;
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker int dataCacheFileAccessMode = fcntl(*sharedHandle, F_GETFL) & O_ACCMODE;
91*89c4ff92SAndroid Build Coastguard Worker if (dataCacheFileAccessMode != O_RDWR)
92*89c4ff92SAndroid Build Coastguard Worker {
93*89c4ff92SAndroid Build Coastguard Worker return !valid;
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker return valid;
97*89c4ff92SAndroid Build Coastguard Worker }
98*89c4ff92SAndroid Build Coastguard Worker
PrepareArmnnModel(const armnn::IRuntimePtr & runtime,const armnn::IGpuAccTunedParametersPtr & clTunedParameters,const DriverOptions & options,const Model & model,const std::vector<SharedHandle> & modelCacheHandle,const std::vector<SharedHandle> & dataCacheHandle,const CacheToken & token,bool float32ToFloat16,Priority priority)99*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
100*89c4ff92SAndroid Build Coastguard Worker const armnn::IRuntimePtr& runtime,
101*89c4ff92SAndroid Build Coastguard Worker const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
102*89c4ff92SAndroid Build Coastguard Worker const DriverOptions& options,
103*89c4ff92SAndroid Build Coastguard Worker const Model& model,
104*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& modelCacheHandle,
105*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& dataCacheHandle,
106*89c4ff92SAndroid Build Coastguard Worker const CacheToken& token,
107*89c4ff92SAndroid Build Coastguard Worker bool float32ToFloat16,
108*89c4ff92SAndroid Build Coastguard Worker Priority priority)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriverImpl::PrepareArmnnModel()";
111*89c4ff92SAndroid Build Coastguard Worker
112*89c4ff92SAndroid Build Coastguard Worker if (!runtime)
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device unavailable";
115*89c4ff92SAndroid Build Coastguard Worker }
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker if (const auto result = validate(model); !result.ok())
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid model passed as input";
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker
122*89c4ff92SAndroid Build Coastguard Worker // Deliberately ignore any unsupported operations requested by the options -
123*89c4ff92SAndroid Build Coastguard Worker // at this point we're being asked to prepare a model that we've already declared support for
124*89c4ff92SAndroid Build Coastguard Worker // and the operation indices may be different to those in getSupportedOperations anyway.
125*89c4ff92SAndroid Build Coastguard Worker std::set<unsigned int> unsupportedOperations;
126*89c4ff92SAndroid Build Coastguard Worker ModelToINetworkTransformer modelConverter(options.GetBackends(),
127*89c4ff92SAndroid Build Coastguard Worker model,
128*89c4ff92SAndroid Build Coastguard Worker unsupportedOperations);
129*89c4ff92SAndroid Build Coastguard Worker
130*89c4ff92SAndroid Build Coastguard Worker if (modelConverter.GetConversionResult() != ConversionResult::Success)
131*89c4ff92SAndroid Build Coastguard Worker {
132*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "ModelToINetworkConverter failed";
133*89c4ff92SAndroid Build Coastguard Worker }
134*89c4ff92SAndroid Build Coastguard Worker
135*89c4ff92SAndroid Build Coastguard Worker // Serialize the network graph to a .armnn file if an output directory
136*89c4ff92SAndroid Build Coastguard Worker // has been specified in the drivers' arguments.
137*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> dataCacheData;
138*89c4ff92SAndroid Build Coastguard Worker bool serializeToFile = dataCacheHandle.size() < 1 ? false : true;
139*89c4ff92SAndroid Build Coastguard Worker auto serializedNetworkFileName =
140*89c4ff92SAndroid Build Coastguard Worker SerializeNetwork(*modelConverter.GetINetwork(),
141*89c4ff92SAndroid Build Coastguard Worker options.GetRequestInputsAndOutputsDumpDir(),
142*89c4ff92SAndroid Build Coastguard Worker dataCacheData,
143*89c4ff92SAndroid Build Coastguard Worker serializeToFile);
144*89c4ff92SAndroid Build Coastguard Worker
145*89c4ff92SAndroid Build Coastguard Worker // Optimize the network
146*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
147*89c4ff92SAndroid Build Coastguard Worker armnn::OptimizerOptionsOpaque OptOptions;
148*89c4ff92SAndroid Build Coastguard Worker OptOptions.SetReduceFp32ToFp16(float32ToFloat16);
149*89c4ff92SAndroid Build Coastguard Worker OptOptions.SetProfilingEnabled(options.IsGpuProfilingEnabled());
150*89c4ff92SAndroid Build Coastguard Worker
151*89c4ff92SAndroid Build Coastguard Worker int cachedFd = -1;
152*89c4ff92SAndroid Build Coastguard Worker bool saveCachedNetwork = options.SaveCachedNetwork();
153*89c4ff92SAndroid Build Coastguard Worker
154*89c4ff92SAndroid Build Coastguard Worker unsigned int numberOfCachedModelFiles = 0;
155*89c4ff92SAndroid Build Coastguard Worker if (modelCacheHandle.size() > 0)
156*89c4ff92SAndroid Build Coastguard Worker {
157*89c4ff92SAndroid Build Coastguard Worker unsigned int index = 0;
158*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : options.GetBackends())
159*89c4ff92SAndroid Build Coastguard Worker {
160*89c4ff92SAndroid Build Coastguard Worker // modelCacheHandle size should be equal to numberOfCachedModelFiles
161*89c4ff92SAndroid Build Coastguard Worker // modelCacheHandle vector should be in same order as backends
162*89c4ff92SAndroid Build Coastguard Worker auto numberOfCacheFiles = GetNumberOfCacheFiles(backend);
163*89c4ff92SAndroid Build Coastguard Worker if (numberOfCacheFiles > 0)
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker numberOfCachedModelFiles += numberOfCacheFiles;
166*89c4ff92SAndroid Build Coastguard Worker // For GpuAcc numberOfCachedFiles is 1
167*89c4ff92SAndroid Build Coastguard Worker if (backend == armnn::Compute::GpuAcc)
168*89c4ff92SAndroid Build Coastguard Worker {
169*89c4ff92SAndroid Build Coastguard Worker cachedFd = *modelCacheHandle[index];
170*89c4ff92SAndroid Build Coastguard Worker saveCachedNetwork = true;
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker index += numberOfCachedModelFiles;
173*89c4ff92SAndroid Build Coastguard Worker }
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker }
176*89c4ff92SAndroid Build Coastguard Worker
177*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions gpuAcc("GpuAcc",
178*89c4ff92SAndroid Build Coastguard Worker {
179*89c4ff92SAndroid Build Coastguard Worker { "FastMathEnabled", options.IsFastMathEnabled() },
180*89c4ff92SAndroid Build Coastguard Worker { "SaveCachedNetwork", saveCachedNetwork },
181*89c4ff92SAndroid Build Coastguard Worker { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
182*89c4ff92SAndroid Build Coastguard Worker { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() },
183*89c4ff92SAndroid Build Coastguard Worker { "CachedFileDescriptor", cachedFd }
184*89c4ff92SAndroid Build Coastguard Worker });
185*89c4ff92SAndroid Build Coastguard Worker
186*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions cpuAcc("CpuAcc",
187*89c4ff92SAndroid Build Coastguard Worker {
188*89c4ff92SAndroid Build Coastguard Worker { "FastMathEnabled", options.IsFastMathEnabled() },
189*89c4ff92SAndroid Build Coastguard Worker { "NumberOfThreads", options.GetNumberOfThreads() }
190*89c4ff92SAndroid Build Coastguard Worker });
191*89c4ff92SAndroid Build Coastguard Worker OptOptions.AddModelOption(gpuAcc);
192*89c4ff92SAndroid Build Coastguard Worker OptOptions.AddModelOption(cpuAcc);
193*89c4ff92SAndroid Build Coastguard Worker
194*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> errMessages;
195*89c4ff92SAndroid Build Coastguard Worker try
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker optNet = armnn::Optimize(*modelConverter.GetINetwork(),
198*89c4ff92SAndroid Build Coastguard Worker options.GetBackends(),
199*89c4ff92SAndroid Build Coastguard Worker runtime->GetDeviceSpec(),
200*89c4ff92SAndroid Build Coastguard Worker OptOptions,
201*89c4ff92SAndroid Build Coastguard Worker errMessages);
202*89c4ff92SAndroid Build Coastguard Worker }
203*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& e)
204*89c4ff92SAndroid Build Coastguard Worker {
205*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << e.what();
206*89c4ff92SAndroid Build Coastguard Worker }
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker // Check that the optimized network is valid.
209*89c4ff92SAndroid Build Coastguard Worker if (!optNet)
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker std::stringstream message;
212*89c4ff92SAndroid Build Coastguard Worker message << "Invalid optimized network";
213*89c4ff92SAndroid Build Coastguard Worker for (const std::string& msg : errMessages)
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker message << "\n" << msg;
216*89c4ff92SAndroid Build Coastguard Worker }
217*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker // Export the optimized network graph to a dot file if an output dump directory
221*89c4ff92SAndroid Build Coastguard Worker // has been specified in the drivers' arguments.
222*89c4ff92SAndroid Build Coastguard Worker std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet,
223*89c4ff92SAndroid Build Coastguard Worker options.GetRequestInputsAndOutputsDumpDir());
224*89c4ff92SAndroid Build Coastguard Worker
225*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime.
226*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId netId = 0;
227*89c4ff92SAndroid Build Coastguard Worker std::string msg;
228*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
229*89c4ff92SAndroid Build Coastguard Worker MemorySource::Undefined,
230*89c4ff92SAndroid Build Coastguard Worker MemorySource::Undefined,
231*89c4ff92SAndroid Build Coastguard Worker options.IsGpuProfilingEnabled());
232*89c4ff92SAndroid Build Coastguard Worker auto numInputs = getMainModel(model).inputIndexes.size();
233*89c4ff92SAndroid Build Coastguard Worker auto numOutputs = getMainModel(model).outputIndexes.size();
234*89c4ff92SAndroid Build Coastguard Worker try
235*89c4ff92SAndroid Build Coastguard Worker {
236*89c4ff92SAndroid Build Coastguard Worker if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
237*89c4ff92SAndroid Build Coastguard Worker {
238*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Network could not be loaded";
239*89c4ff92SAndroid Build Coastguard Worker }
240*89c4ff92SAndroid Build Coastguard Worker }
241*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& e)
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker std::stringstream message;
244*89c4ff92SAndroid Build Coastguard Worker message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
245*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
246*89c4ff92SAndroid Build Coastguard Worker }
247*89c4ff92SAndroid Build Coastguard Worker
248*89c4ff92SAndroid Build Coastguard Worker // Now that we have a networkId for the graph rename the exported files to use it
249*89c4ff92SAndroid Build Coastguard Worker // so that we can associate the graph file and the input/output tensor exported files
250*89c4ff92SAndroid Build Coastguard Worker RenameExportedFiles(serializedNetworkFileName,
251*89c4ff92SAndroid Build Coastguard Worker dotGraphFileName,
252*89c4ff92SAndroid Build Coastguard Worker options.GetRequestInputsAndOutputsDumpDir(),
253*89c4ff92SAndroid Build Coastguard Worker netId);
254*89c4ff92SAndroid Build Coastguard Worker
255*89c4ff92SAndroid Build Coastguard Worker // Cache the model
256*89c4ff92SAndroid Build Coastguard Worker size_t hashValue = 0;
257*89c4ff92SAndroid Build Coastguard Worker if (dataCacheHandle.size() == 1 )
258*89c4ff92SAndroid Build Coastguard Worker {
259*89c4ff92SAndroid Build Coastguard Worker hashValue = Hash(dataCacheData);
260*89c4ff92SAndroid Build Coastguard Worker }
261*89c4ff92SAndroid Build Coastguard Worker
262*89c4ff92SAndroid Build Coastguard Worker // Cache the model data
263*89c4ff92SAndroid Build Coastguard Worker if (modelCacheHandle.size() > 0)
264*89c4ff92SAndroid Build Coastguard Worker {
265*89c4ff92SAndroid Build Coastguard Worker if (modelCacheHandle.size() == numberOfCachedModelFiles)
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker for (uint32_t i = 0; i < modelCacheHandle.size(); ++i)
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker int modelCacheFileAccessMode = fcntl(*modelCacheHandle[i], F_GETFL) & O_ACCMODE;
270*89c4ff92SAndroid Build Coastguard Worker if (modelCacheFileAccessMode != O_RDONLY)
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker struct stat statBuffer;
273*89c4ff92SAndroid Build Coastguard Worker if (fstat(*modelCacheHandle[i], &statBuffer) == 0)
274*89c4ff92SAndroid Build Coastguard Worker {
275*89c4ff92SAndroid Build Coastguard Worker long modelDataSize = statBuffer.st_size;
276*89c4ff92SAndroid Build Coastguard Worker if (modelDataSize > 0)
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> modelData(modelDataSize);
279*89c4ff92SAndroid Build Coastguard Worker pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0);
280*89c4ff92SAndroid Build Coastguard Worker hashValue ^= Hash(modelData);
281*89c4ff92SAndroid Build Coastguard Worker }
282*89c4ff92SAndroid Build Coastguard Worker }
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker }
285*89c4ff92SAndroid Build Coastguard Worker }
286*89c4ff92SAndroid Build Coastguard Worker }
287*89c4ff92SAndroid Build Coastguard Worker if (dataCacheHandle.size() == 1 && hashValue != 0)
288*89c4ff92SAndroid Build Coastguard Worker {
289*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> theHashValue(sizeof(hashValue));
290*89c4ff92SAndroid Build Coastguard Worker ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue));
291*89c4ff92SAndroid Build Coastguard Worker
292*89c4ff92SAndroid Build Coastguard Worker write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size());
293*89c4ff92SAndroid Build Coastguard Worker pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size());
294*89c4ff92SAndroid Build Coastguard Worker }
295*89c4ff92SAndroid Build Coastguard Worker
296*89c4ff92SAndroid Build Coastguard Worker bool executeWithDummyInputs = (std::find(options.GetBackends().begin(),
297*89c4ff92SAndroid Build Coastguard Worker options.GetBackends().end(),
298*89c4ff92SAndroid Build Coastguard Worker armnn::Compute::GpuAcc) != options.GetBackends().end());
299*89c4ff92SAndroid Build Coastguard Worker
300*89c4ff92SAndroid Build Coastguard Worker auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
301*89c4ff92SAndroid Build Coastguard Worker runtime.get(),
302*89c4ff92SAndroid Build Coastguard Worker model,
303*89c4ff92SAndroid Build Coastguard Worker options.GetRequestInputsAndOutputsDumpDir(),
304*89c4ff92SAndroid Build Coastguard Worker options.IsGpuProfilingEnabled(),
305*89c4ff92SAndroid Build Coastguard Worker priority);
306*89c4ff92SAndroid Build Coastguard Worker
307*89c4ff92SAndroid Build Coastguard Worker // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
308*89c4ff92SAndroid Build Coastguard Worker // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
309*89c4ff92SAndroid Build Coastguard Worker // Only run this if the GpuAcc backend has been added to options
310*89c4ff92SAndroid Build Coastguard Worker if (std::find(options.GetBackends().begin(),
311*89c4ff92SAndroid Build Coastguard Worker options.GetBackends().end(),
312*89c4ff92SAndroid Build Coastguard Worker armnn::Compute::GpuAcc) != options.GetBackends().end())
313*89c4ff92SAndroid Build Coastguard Worker {
314*89c4ff92SAndroid Build Coastguard Worker if (!preparedModel->ExecuteWithDummyInputs(numInputs, numOutputs))
315*89c4ff92SAndroid Build Coastguard Worker {
316*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Network could not be executed";
317*89c4ff92SAndroid Build Coastguard Worker }
318*89c4ff92SAndroid Build Coastguard Worker
319*89c4ff92SAndroid Build Coastguard Worker if (clTunedParameters &&
320*89c4ff92SAndroid Build Coastguard Worker options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
321*89c4ff92SAndroid Build Coastguard Worker {
322*89c4ff92SAndroid Build Coastguard Worker // Now that we've done one inference the CL kernel parameters will have been tuned,
323*89c4ff92SAndroid Build Coastguard Worker // so save the updated file.
324*89c4ff92SAndroid Build Coastguard Worker try
325*89c4ff92SAndroid Build Coastguard Worker {
326*89c4ff92SAndroid Build Coastguard Worker clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
327*89c4ff92SAndroid Build Coastguard Worker }
328*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& error)
329*89c4ff92SAndroid Build Coastguard Worker {
330*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file"
331*89c4ff92SAndroid Build Coastguard Worker << options.GetClTunedParametersFile().c_str() << error.what();
332*89c4ff92SAndroid Build Coastguard Worker }
333*89c4ff92SAndroid Build Coastguard Worker }
334*89c4ff92SAndroid Build Coastguard Worker }
335*89c4ff92SAndroid Build Coastguard Worker return std::move(preparedModel);
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker
PrepareArmnnModelFromCache(const armnn::IRuntimePtr & runtime,const armnn::IGpuAccTunedParametersPtr & clTunedParameters,const DriverOptions & options,const std::vector<SharedHandle> & modelCacheHandle,const std::vector<SharedHandle> & dataCacheHandle,const CacheToken & token,bool float32ToFloat16)338*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
339*89c4ff92SAndroid Build Coastguard Worker const armnn::IRuntimePtr& runtime,
340*89c4ff92SAndroid Build Coastguard Worker const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
341*89c4ff92SAndroid Build Coastguard Worker const DriverOptions& options,
342*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& modelCacheHandle,
343*89c4ff92SAndroid Build Coastguard Worker const std::vector<SharedHandle>& dataCacheHandle,
344*89c4ff92SAndroid Build Coastguard Worker const CacheToken& token,
345*89c4ff92SAndroid Build Coastguard Worker bool float32ToFloat16)
346*89c4ff92SAndroid Build Coastguard Worker {
347*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriverImpl::PrepareArmnnModelFromCache()";
348*89c4ff92SAndroid Build Coastguard Worker
349*89c4ff92SAndroid Build Coastguard Worker if (!runtime)
350*89c4ff92SAndroid Build Coastguard Worker {
351*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE)
352*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Device unavailable";
353*89c4ff92SAndroid Build Coastguard Worker }
354*89c4ff92SAndroid Build Coastguard Worker
355*89c4ff92SAndroid Build Coastguard Worker if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN)
356*89c4ff92SAndroid Build Coastguard Worker {
357*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
358*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Token size does not match!";
359*89c4ff92SAndroid Build Coastguard Worker }
360*89c4ff92SAndroid Build Coastguard Worker
361*89c4ff92SAndroid Build Coastguard Worker // Validate dataCacheHandle
362*89c4ff92SAndroid Build Coastguard Worker if (dataCacheHandle.size() != 1)
363*89c4ff92SAndroid Build Coastguard Worker {
364*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
365*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
366*89c4ff92SAndroid Build Coastguard Worker }
367*89c4ff92SAndroid Build Coastguard Worker
368*89c4ff92SAndroid Build Coastguard Worker if (!ValidateSharedHandle(dataCacheHandle[0]))
369*89c4ff92SAndroid Build Coastguard Worker {
370*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
371*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
372*89c4ff92SAndroid Build Coastguard Worker }
373*89c4ff92SAndroid Build Coastguard Worker
374*89c4ff92SAndroid Build Coastguard Worker size_t cachedDataSize = 0;
375*89c4ff92SAndroid Build Coastguard Worker struct stat dataStatBuffer;
376*89c4ff92SAndroid Build Coastguard Worker if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0)
377*89c4ff92SAndroid Build Coastguard Worker {
378*89c4ff92SAndroid Build Coastguard Worker cachedDataSize = dataStatBuffer.st_size;
379*89c4ff92SAndroid Build Coastguard Worker }
380*89c4ff92SAndroid Build Coastguard Worker if (cachedDataSize == 0)
381*89c4ff92SAndroid Build Coastguard Worker {
382*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
383*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!";
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker
386*89c4ff92SAndroid Build Coastguard Worker // Check if model files cached they match the expected value
387*89c4ff92SAndroid Build Coastguard Worker unsigned int numberOfCachedModelFiles = 0;
388*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : options.GetBackends())
389*89c4ff92SAndroid Build Coastguard Worker {
390*89c4ff92SAndroid Build Coastguard Worker numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
391*89c4ff92SAndroid Build Coastguard Worker }
392*89c4ff92SAndroid Build Coastguard Worker if (modelCacheHandle.size() != numberOfCachedModelFiles)
393*89c4ff92SAndroid Build Coastguard Worker {
394*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
395*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match.";
396*89c4ff92SAndroid Build Coastguard Worker }
397*89c4ff92SAndroid Build Coastguard Worker
398*89c4ff92SAndroid Build Coastguard Worker // Read the hashValue
399*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> hashValue(sizeof(size_t));
400*89c4ff92SAndroid Build Coastguard Worker pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0);
401*89c4ff92SAndroid Build Coastguard Worker
402*89c4ff92SAndroid Build Coastguard Worker // Read the model
403*89c4ff92SAndroid Build Coastguard Worker if (cachedDataSize < hashValue.size())
404*89c4ff92SAndroid Build Coastguard Worker {
405*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
406*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): cachedDataSize is less than hashValue!";
407*89c4ff92SAndroid Build Coastguard Worker }
408*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> dataCacheData(cachedDataSize - hashValue.size());
409*89c4ff92SAndroid Build Coastguard Worker pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size());
410*89c4ff92SAndroid Build Coastguard Worker auto calculatedHashValue = Hash(dataCacheData);
411*89c4ff92SAndroid Build Coastguard Worker
412*89c4ff92SAndroid Build Coastguard Worker int gpuAccCachedFd = -1;
413*89c4ff92SAndroid Build Coastguard Worker if (modelCacheHandle.size() > 0)
414*89c4ff92SAndroid Build Coastguard Worker {
415*89c4ff92SAndroid Build Coastguard Worker unsigned int index = 0;
416*89c4ff92SAndroid Build Coastguard Worker for (auto& backend : options.GetBackends())
417*89c4ff92SAndroid Build Coastguard Worker {
418*89c4ff92SAndroid Build Coastguard Worker // modelCacheHandle size should be equal to numberOfCachedModelFiles
419*89c4ff92SAndroid Build Coastguard Worker // modelCacheHandle vector should be in same order as backends
420*89c4ff92SAndroid Build Coastguard Worker auto numberOfCacheFiles = GetNumberOfCacheFiles(backend);
421*89c4ff92SAndroid Build Coastguard Worker if (numberOfCacheFiles > 0)
422*89c4ff92SAndroid Build Coastguard Worker {
423*89c4ff92SAndroid Build Coastguard Worker if (!ValidateSharedHandle(modelCacheHandle[index]))
424*89c4ff92SAndroid Build Coastguard Worker {
425*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
426*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Invalid model cache handle!";
427*89c4ff92SAndroid Build Coastguard Worker }
428*89c4ff92SAndroid Build Coastguard Worker int cachedFd = *modelCacheHandle[index];
429*89c4ff92SAndroid Build Coastguard Worker struct stat statBuffer;
430*89c4ff92SAndroid Build Coastguard Worker if (fstat(cachedFd, &statBuffer) == 0)
431*89c4ff92SAndroid Build Coastguard Worker {
432*89c4ff92SAndroid Build Coastguard Worker long modelDataSize = statBuffer.st_size;
433*89c4ff92SAndroid Build Coastguard Worker if (modelDataSize > 0)
434*89c4ff92SAndroid Build Coastguard Worker {
435*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> modelData(modelDataSize);
436*89c4ff92SAndroid Build Coastguard Worker pread(cachedFd, modelData.data(), modelData.size(), 0);
437*89c4ff92SAndroid Build Coastguard Worker calculatedHashValue ^= Hash(modelData);
438*89c4ff92SAndroid Build Coastguard Worker
439*89c4ff92SAndroid Build Coastguard Worker if (backend == armnn::Compute::GpuAcc)
440*89c4ff92SAndroid Build Coastguard Worker {
441*89c4ff92SAndroid Build Coastguard Worker gpuAccCachedFd = cachedFd;
442*89c4ff92SAndroid Build Coastguard Worker }
443*89c4ff92SAndroid Build Coastguard Worker }
444*89c4ff92SAndroid Build Coastguard Worker }
445*89c4ff92SAndroid Build Coastguard Worker index += numberOfCacheFiles;
446*89c4ff92SAndroid Build Coastguard Worker }
447*89c4ff92SAndroid Build Coastguard Worker }
448*89c4ff92SAndroid Build Coastguard Worker }
449*89c4ff92SAndroid Build Coastguard Worker
450*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> calculatedHashData(sizeof(calculatedHashValue));
451*89c4ff92SAndroid Build Coastguard Worker ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue));
452*89c4ff92SAndroid Build Coastguard Worker if (hashValue != calculatedHashData)
453*89c4ff92SAndroid Build Coastguard Worker {
454*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
455*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!";
456*89c4ff92SAndroid Build Coastguard Worker }
457*89c4ff92SAndroid Build Coastguard Worker
458*89c4ff92SAndroid Build Coastguard Worker // Deserialize the network..
459*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
460*89c4ff92SAndroid Build Coastguard Worker try
461*89c4ff92SAndroid Build Coastguard Worker {
462*89c4ff92SAndroid Build Coastguard Worker network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
463*89c4ff92SAndroid Build Coastguard Worker }
464*89c4ff92SAndroid Build Coastguard Worker catch (std::exception&)
465*89c4ff92SAndroid Build Coastguard Worker {
466*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
467*89c4ff92SAndroid Build Coastguard Worker << "ArmnnDriverImpl::prepareModelFromCache(): Exception caught from Deserializer!";
468*89c4ff92SAndroid Build Coastguard Worker }
469*89c4ff92SAndroid Build Coastguard Worker
470*89c4ff92SAndroid Build Coastguard Worker // Optimize the network
471*89c4ff92SAndroid Build Coastguard Worker armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
472*89c4ff92SAndroid Build Coastguard Worker armnn::OptimizerOptionsOpaque OptOptions;
473*89c4ff92SAndroid Build Coastguard Worker OptOptions.SetReduceFp32ToFp16(float32ToFloat16);
474*89c4ff92SAndroid Build Coastguard Worker OptOptions.SetProfilingEnabled(options.IsGpuProfilingEnabled());
475*89c4ff92SAndroid Build Coastguard Worker
476*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions gpuAcc("GpuAcc",
477*89c4ff92SAndroid Build Coastguard Worker {
478*89c4ff92SAndroid Build Coastguard Worker { "FastMathEnabled", options.IsFastMathEnabled() },
479*89c4ff92SAndroid Build Coastguard Worker { "SaveCachedNetwork", false },
480*89c4ff92SAndroid Build Coastguard Worker { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
481*89c4ff92SAndroid Build Coastguard Worker { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() },
482*89c4ff92SAndroid Build Coastguard Worker { "CachedFileDescriptor", gpuAccCachedFd }
483*89c4ff92SAndroid Build Coastguard Worker });
484*89c4ff92SAndroid Build Coastguard Worker
485*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions cpuAcc("CpuAcc",
486*89c4ff92SAndroid Build Coastguard Worker {
487*89c4ff92SAndroid Build Coastguard Worker { "FastMathEnabled", options.IsFastMathEnabled() },
488*89c4ff92SAndroid Build Coastguard Worker { "NumberOfThreads", options.GetNumberOfThreads() }
489*89c4ff92SAndroid Build Coastguard Worker });
490*89c4ff92SAndroid Build Coastguard Worker OptOptions.AddModelOption(gpuAcc);
491*89c4ff92SAndroid Build Coastguard Worker OptOptions.AddModelOption(cpuAcc);
492*89c4ff92SAndroid Build Coastguard Worker
493*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> errMessages;
494*89c4ff92SAndroid Build Coastguard Worker try
495*89c4ff92SAndroid Build Coastguard Worker {
496*89c4ff92SAndroid Build Coastguard Worker optNet = armnn::Optimize(*network.get(),
497*89c4ff92SAndroid Build Coastguard Worker options.GetBackends(),
498*89c4ff92SAndroid Build Coastguard Worker runtime->GetDeviceSpec(),
499*89c4ff92SAndroid Build Coastguard Worker OptOptions,
500*89c4ff92SAndroid Build Coastguard Worker errMessages);
501*89c4ff92SAndroid Build Coastguard Worker }
502*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& e)
503*89c4ff92SAndroid Build Coastguard Worker {
504*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << e.what();
505*89c4ff92SAndroid Build Coastguard Worker }
506*89c4ff92SAndroid Build Coastguard Worker
507*89c4ff92SAndroid Build Coastguard Worker // Check that the optimized network is valid.
508*89c4ff92SAndroid Build Coastguard Worker if (!optNet)
509*89c4ff92SAndroid Build Coastguard Worker {
510*89c4ff92SAndroid Build Coastguard Worker std::stringstream message;
511*89c4ff92SAndroid Build Coastguard Worker message << "Invalid optimized network";
512*89c4ff92SAndroid Build Coastguard Worker for (const std::string& msg : errMessages)
513*89c4ff92SAndroid Build Coastguard Worker {
514*89c4ff92SAndroid Build Coastguard Worker message << "\n" << msg;
515*89c4ff92SAndroid Build Coastguard Worker }
516*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
517*89c4ff92SAndroid Build Coastguard Worker }
518*89c4ff92SAndroid Build Coastguard Worker
519*89c4ff92SAndroid Build Coastguard Worker // Export the optimized network graph to a dot file if an output dump directory
520*89c4ff92SAndroid Build Coastguard Worker // has been specified in the drivers' arguments.
521*89c4ff92SAndroid Build Coastguard Worker std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet,
522*89c4ff92SAndroid Build Coastguard Worker options.GetRequestInputsAndOutputsDumpDir());
523*89c4ff92SAndroid Build Coastguard Worker
524*89c4ff92SAndroid Build Coastguard Worker // Load it into the runtime.
525*89c4ff92SAndroid Build Coastguard Worker armnn::NetworkId netId = 0;
526*89c4ff92SAndroid Build Coastguard Worker std::string msg;
527*89c4ff92SAndroid Build Coastguard Worker armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
528*89c4ff92SAndroid Build Coastguard Worker MemorySource::Undefined,
529*89c4ff92SAndroid Build Coastguard Worker MemorySource::Undefined,
530*89c4ff92SAndroid Build Coastguard Worker options.IsGpuProfilingEnabled());
531*89c4ff92SAndroid Build Coastguard Worker try
532*89c4ff92SAndroid Build Coastguard Worker {
533*89c4ff92SAndroid Build Coastguard Worker if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
534*89c4ff92SAndroid Build Coastguard Worker {
535*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Network could not be loaded";
536*89c4ff92SAndroid Build Coastguard Worker }
537*89c4ff92SAndroid Build Coastguard Worker }
538*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& e)
539*89c4ff92SAndroid Build Coastguard Worker {
540*89c4ff92SAndroid Build Coastguard Worker std::stringstream message;
541*89c4ff92SAndroid Build Coastguard Worker message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
542*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
543*89c4ff92SAndroid Build Coastguard Worker }
544*89c4ff92SAndroid Build Coastguard Worker
545*89c4ff92SAndroid Build Coastguard Worker auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
546*89c4ff92SAndroid Build Coastguard Worker runtime.get(),
547*89c4ff92SAndroid Build Coastguard Worker options.GetRequestInputsAndOutputsDumpDir(),
548*89c4ff92SAndroid Build Coastguard Worker options.IsGpuProfilingEnabled(),
549*89c4ff92SAndroid Build Coastguard Worker Priority::MEDIUM,
550*89c4ff92SAndroid Build Coastguard Worker true);
551*89c4ff92SAndroid Build Coastguard Worker return std::move(preparedModel);
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker
GetCapabilities(const armnn::IRuntimePtr & runtime)554*89c4ff92SAndroid Build Coastguard Worker const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime)
555*89c4ff92SAndroid Build Coastguard Worker {
556*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnDriverImpl::GetCapabilities()";
557*89c4ff92SAndroid Build Coastguard Worker static const Capabilities theCapabilities = GenerateCapabilities();
558*89c4ff92SAndroid Build Coastguard Worker return theCapabilities;
559*89c4ff92SAndroid Build Coastguard Worker }
560*89c4ff92SAndroid Build Coastguard Worker
561*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
562