xref: /aosp_15_r20/external/armnn/shim/sl/canonical/ArmnnDriverImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnDriverImpl.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnPreparedModel.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "ModelToINetworkTransformer.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "SystemPropertiesUtils.hpp"
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <log/log.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <sys/stat.h>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker namespace
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
GenerateCapabilities()19*89c4ff92SAndroid Build Coastguard Worker Capabilities GenerateCapabilities()
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "ArmnnDriverImpl::GenerateCapabilities()";
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker     float defaultPerfValue = .1f;
24*89c4ff92SAndroid Build Coastguard Worker     const Capabilities::PerformanceInfo defaultPerfInfo = { /* execTime */ defaultPerfValue,
25*89c4ff92SAndroid Build Coastguard Worker                                                             /* powerUsage */ defaultPerfValue
26*89c4ff92SAndroid Build Coastguard Worker                                                           };
27*89c4ff92SAndroid Build Coastguard Worker     std::vector<OperandType> operandsTypes({
28*89c4ff92SAndroid Build Coastguard Worker                 OperandType::FLOAT32,
29*89c4ff92SAndroid Build Coastguard Worker                 OperandType::INT32,
30*89c4ff92SAndroid Build Coastguard Worker                 OperandType::UINT32,
31*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_FLOAT32,
32*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_INT32,
33*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_QUANT8_ASYMM,
34*89c4ff92SAndroid Build Coastguard Worker                 OperandType::BOOL,
35*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_QUANT16_SYMM,
36*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_FLOAT16,
37*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_BOOL8,
38*89c4ff92SAndroid Build Coastguard Worker                 OperandType::FLOAT16,
39*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_QUANT8_SYMM_PER_CHANNEL,
40*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_QUANT16_ASYMM,
41*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_QUANT8_SYMM,
42*89c4ff92SAndroid Build Coastguard Worker                 OperandType::TENSOR_QUANT8_ASYMM_SIGNED,
43*89c4ff92SAndroid Build Coastguard Worker     });
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     std::vector<Capabilities::OperandPerformance> operandPerformances;
46*89c4ff92SAndroid Build Coastguard Worker     operandPerformances.reserve(operandsTypes.size());
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     for (auto opType : operandsTypes)
49*89c4ff92SAndroid Build Coastguard Worker     {
50*89c4ff92SAndroid Build Coastguard Worker         operandPerformances.push_back(
51*89c4ff92SAndroid Build Coastguard Worker                 Capabilities::OperandPerformance{ /* type */ opType, /* info */ defaultPerfInfo });
52*89c4ff92SAndroid Build Coastguard Worker     }
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     auto operandPerformanceTable =
55*89c4ff92SAndroid Build Coastguard Worker                Capabilities::OperandPerformanceTable::create(std::move(operandPerformances)).value();
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker     return { /* relaxedFloat32toFloat16PerformanceScalar */ defaultPerfInfo,
58*89c4ff92SAndroid Build Coastguard Worker              /* relaxedFloat32toFloat16PerformanceTensor */ defaultPerfInfo,
59*89c4ff92SAndroid Build Coastguard Worker              /* operandPerformance */ std::move(operandPerformanceTable),
60*89c4ff92SAndroid Build Coastguard Worker              /* ifPerformance */ defaultPerfInfo,
61*89c4ff92SAndroid Build Coastguard Worker              /* whilePerformance */ defaultPerfInfo };
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker 
/// Computes a simple polynomial hash (hash = hash * 31 + byte) of a byte buffer, seeded with the
/// buffer's size. Used in this file to fingerprint the data written to / read from the cache files.
///
/// @param cacheData Bytes to hash; not modified (taken by const reference so const buffers and
///                  temporaries can be hashed too).
/// @return The hash value; size_t arithmetic wraps on overflow, which is well defined for unsigned types.
size_t Hash(const std::vector<uint8_t>& cacheData)
{
    std::size_t hash = cacheData.size();
    for (const uint8_t byte : cacheData)
    {
        // (hash << 5) - hash == hash * 31
        hash = ((hash << 5) - hash) + byte;
    }
    return hash;
}
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker 
ValidateSharedHandle(const SharedHandle & sharedHandle)81*89c4ff92SAndroid Build Coastguard Worker bool ArmnnDriverImpl::ValidateSharedHandle(const SharedHandle& sharedHandle)
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker     bool valid = true;
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker     if (*sharedHandle < 0)
86*89c4ff92SAndroid Build Coastguard Worker     {
87*89c4ff92SAndroid Build Coastguard Worker         return !valid;
88*89c4ff92SAndroid Build Coastguard Worker     }
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     int dataCacheFileAccessMode = fcntl(*sharedHandle, F_GETFL) & O_ACCMODE;
91*89c4ff92SAndroid Build Coastguard Worker     if (dataCacheFileAccessMode != O_RDWR)
92*89c4ff92SAndroid Build Coastguard Worker     {
93*89c4ff92SAndroid Build Coastguard Worker         return !valid;
94*89c4ff92SAndroid Build Coastguard Worker     }
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker     return valid;
97*89c4ff92SAndroid Build Coastguard Worker }
98*89c4ff92SAndroid Build Coastguard Worker 
PrepareArmnnModel(const armnn::IRuntimePtr & runtime,const armnn::IGpuAccTunedParametersPtr & clTunedParameters,const DriverOptions & options,const Model & model,const std::vector<SharedHandle> & modelCacheHandle,const std::vector<SharedHandle> & dataCacheHandle,const CacheToken & token,bool float32ToFloat16,Priority priority)99*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModel(
100*89c4ff92SAndroid Build Coastguard Worker     const armnn::IRuntimePtr& runtime,
101*89c4ff92SAndroid Build Coastguard Worker     const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
102*89c4ff92SAndroid Build Coastguard Worker     const DriverOptions& options,
103*89c4ff92SAndroid Build Coastguard Worker     const Model& model,
104*89c4ff92SAndroid Build Coastguard Worker     const std::vector<SharedHandle>& modelCacheHandle,
105*89c4ff92SAndroid Build Coastguard Worker     const std::vector<SharedHandle>& dataCacheHandle,
106*89c4ff92SAndroid Build Coastguard Worker     const CacheToken& token,
107*89c4ff92SAndroid Build Coastguard Worker     bool float32ToFloat16,
108*89c4ff92SAndroid Build Coastguard Worker     Priority priority)
109*89c4ff92SAndroid Build Coastguard Worker {
110*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "ArmnnDriverImpl::PrepareArmnnModel()";
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     if (!runtime)
113*89c4ff92SAndroid Build Coastguard Worker     {
114*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device unavailable";
115*89c4ff92SAndroid Build Coastguard Worker     }
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker     if (const auto result = validate(model); !result.ok())
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid model passed as input";
120*89c4ff92SAndroid Build Coastguard Worker     }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     // Deliberately ignore any unsupported operations requested by the options -
123*89c4ff92SAndroid Build Coastguard Worker     // at this point we're being asked to prepare a model that we've already declared support for
124*89c4ff92SAndroid Build Coastguard Worker     // and the operation indices may be different to those in getSupportedOperations anyway.
125*89c4ff92SAndroid Build Coastguard Worker     std::set<unsigned int> unsupportedOperations;
126*89c4ff92SAndroid Build Coastguard Worker     ModelToINetworkTransformer modelConverter(options.GetBackends(),
127*89c4ff92SAndroid Build Coastguard Worker                                               model,
128*89c4ff92SAndroid Build Coastguard Worker                                               unsupportedOperations);
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker     if (modelConverter.GetConversionResult() != ConversionResult::Success)
131*89c4ff92SAndroid Build Coastguard Worker     {
132*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "ModelToINetworkConverter failed";
133*89c4ff92SAndroid Build Coastguard Worker     }
134*89c4ff92SAndroid Build Coastguard Worker 
135*89c4ff92SAndroid Build Coastguard Worker     // Serialize the network graph to a .armnn file if an output directory
136*89c4ff92SAndroid Build Coastguard Worker     // has been specified in the drivers' arguments.
137*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> dataCacheData;
138*89c4ff92SAndroid Build Coastguard Worker     bool serializeToFile = dataCacheHandle.size() < 1 ? false : true;
139*89c4ff92SAndroid Build Coastguard Worker     auto serializedNetworkFileName =
140*89c4ff92SAndroid Build Coastguard Worker             SerializeNetwork(*modelConverter.GetINetwork(),
141*89c4ff92SAndroid Build Coastguard Worker                              options.GetRequestInputsAndOutputsDumpDir(),
142*89c4ff92SAndroid Build Coastguard Worker                              dataCacheData,
143*89c4ff92SAndroid Build Coastguard Worker                              serializeToFile);
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
146*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
147*89c4ff92SAndroid Build Coastguard Worker     armnn::OptimizerOptionsOpaque OptOptions;
148*89c4ff92SAndroid Build Coastguard Worker     OptOptions.SetReduceFp32ToFp16(float32ToFloat16);
149*89c4ff92SAndroid Build Coastguard Worker     OptOptions.SetProfilingEnabled(options.IsGpuProfilingEnabled());
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker     int cachedFd = -1;
152*89c4ff92SAndroid Build Coastguard Worker     bool saveCachedNetwork = options.SaveCachedNetwork();
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     unsigned int numberOfCachedModelFiles = 0;
155*89c4ff92SAndroid Build Coastguard Worker     if (modelCacheHandle.size() > 0)
156*89c4ff92SAndroid Build Coastguard Worker     {
157*89c4ff92SAndroid Build Coastguard Worker         unsigned int index = 0;
158*89c4ff92SAndroid Build Coastguard Worker         for (auto& backend : options.GetBackends())
159*89c4ff92SAndroid Build Coastguard Worker         {
160*89c4ff92SAndroid Build Coastguard Worker             // modelCacheHandle size should be equal to numberOfCachedModelFiles
161*89c4ff92SAndroid Build Coastguard Worker             // modelCacheHandle vector should be in same order as backends
162*89c4ff92SAndroid Build Coastguard Worker             auto numberOfCacheFiles = GetNumberOfCacheFiles(backend);
163*89c4ff92SAndroid Build Coastguard Worker             if (numberOfCacheFiles > 0)
164*89c4ff92SAndroid Build Coastguard Worker             {
165*89c4ff92SAndroid Build Coastguard Worker                 numberOfCachedModelFiles += numberOfCacheFiles;
166*89c4ff92SAndroid Build Coastguard Worker                 // For GpuAcc numberOfCachedFiles is 1
167*89c4ff92SAndroid Build Coastguard Worker                 if (backend == armnn::Compute::GpuAcc)
168*89c4ff92SAndroid Build Coastguard Worker                 {
169*89c4ff92SAndroid Build Coastguard Worker                     cachedFd = *modelCacheHandle[index];
170*89c4ff92SAndroid Build Coastguard Worker                     saveCachedNetwork = true;
171*89c4ff92SAndroid Build Coastguard Worker                 }
172*89c4ff92SAndroid Build Coastguard Worker                 index += numberOfCachedModelFiles;
173*89c4ff92SAndroid Build Coastguard Worker             }
174*89c4ff92SAndroid Build Coastguard Worker         }
175*89c4ff92SAndroid Build Coastguard Worker     }
176*89c4ff92SAndroid Build Coastguard Worker 
177*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions gpuAcc("GpuAcc",
178*89c4ff92SAndroid Build Coastguard Worker     {
179*89c4ff92SAndroid Build Coastguard Worker         { "FastMathEnabled", options.IsFastMathEnabled() },
180*89c4ff92SAndroid Build Coastguard Worker         { "SaveCachedNetwork", saveCachedNetwork },
181*89c4ff92SAndroid Build Coastguard Worker         { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
182*89c4ff92SAndroid Build Coastguard Worker         { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() },
183*89c4ff92SAndroid Build Coastguard Worker         { "CachedFileDescriptor", cachedFd }
184*89c4ff92SAndroid Build Coastguard Worker     });
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions cpuAcc("CpuAcc",
187*89c4ff92SAndroid Build Coastguard Worker     {
188*89c4ff92SAndroid Build Coastguard Worker         { "FastMathEnabled", options.IsFastMathEnabled() },
189*89c4ff92SAndroid Build Coastguard Worker         { "NumberOfThreads", options.GetNumberOfThreads() }
190*89c4ff92SAndroid Build Coastguard Worker     });
191*89c4ff92SAndroid Build Coastguard Worker     OptOptions.AddModelOption(gpuAcc);
192*89c4ff92SAndroid Build Coastguard Worker     OptOptions.AddModelOption(cpuAcc);
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> errMessages;
195*89c4ff92SAndroid Build Coastguard Worker     try
196*89c4ff92SAndroid Build Coastguard Worker     {
197*89c4ff92SAndroid Build Coastguard Worker         optNet = armnn::Optimize(*modelConverter.GetINetwork(),
198*89c4ff92SAndroid Build Coastguard Worker                                  options.GetBackends(),
199*89c4ff92SAndroid Build Coastguard Worker                                  runtime->GetDeviceSpec(),
200*89c4ff92SAndroid Build Coastguard Worker                                  OptOptions,
201*89c4ff92SAndroid Build Coastguard Worker                                  errMessages);
202*89c4ff92SAndroid Build Coastguard Worker     }
203*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception& e)
204*89c4ff92SAndroid Build Coastguard Worker     {
205*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << e.what();
206*89c4ff92SAndroid Build Coastguard Worker     }
207*89c4ff92SAndroid Build Coastguard Worker 
208*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimized network is valid.
209*89c4ff92SAndroid Build Coastguard Worker     if (!optNet)
210*89c4ff92SAndroid Build Coastguard Worker     {
211*89c4ff92SAndroid Build Coastguard Worker         std::stringstream message;
212*89c4ff92SAndroid Build Coastguard Worker         message << "Invalid optimized network";
213*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& msg : errMessages)
214*89c4ff92SAndroid Build Coastguard Worker         {
215*89c4ff92SAndroid Build Coastguard Worker             message << "\n" << msg;
216*89c4ff92SAndroid Build Coastguard Worker         }
217*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
218*89c4ff92SAndroid Build Coastguard Worker     }
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker     // Export the optimized network graph to a dot file if an output dump directory
221*89c4ff92SAndroid Build Coastguard Worker     // has been specified in the drivers' arguments.
222*89c4ff92SAndroid Build Coastguard Worker     std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet,
223*89c4ff92SAndroid Build Coastguard Worker                                                                options.GetRequestInputsAndOutputsDumpDir());
224*89c4ff92SAndroid Build Coastguard Worker 
225*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime.
226*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId netId = 0;
227*89c4ff92SAndroid Build Coastguard Worker     std::string msg;
228*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
229*89c4ff92SAndroid Build Coastguard Worker                                                 MemorySource::Undefined,
230*89c4ff92SAndroid Build Coastguard Worker                                                 MemorySource::Undefined,
231*89c4ff92SAndroid Build Coastguard Worker                                                 options.IsGpuProfilingEnabled());
232*89c4ff92SAndroid Build Coastguard Worker     auto numInputs  = getMainModel(model).inputIndexes.size();
233*89c4ff92SAndroid Build Coastguard Worker     auto numOutputs = getMainModel(model).outputIndexes.size();
234*89c4ff92SAndroid Build Coastguard Worker     try
235*89c4ff92SAndroid Build Coastguard Worker     {
236*89c4ff92SAndroid Build Coastguard Worker         if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
237*89c4ff92SAndroid Build Coastguard Worker         {
238*89c4ff92SAndroid Build Coastguard Worker             return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Network could not be loaded";
239*89c4ff92SAndroid Build Coastguard Worker         }
240*89c4ff92SAndroid Build Coastguard Worker     }
241*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception& e)
242*89c4ff92SAndroid Build Coastguard Worker     {
243*89c4ff92SAndroid Build Coastguard Worker         std::stringstream message;
244*89c4ff92SAndroid Build Coastguard Worker         message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
245*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
246*89c4ff92SAndroid Build Coastguard Worker     }
247*89c4ff92SAndroid Build Coastguard Worker 
248*89c4ff92SAndroid Build Coastguard Worker     // Now that we have a networkId for the graph rename the exported files to use it
249*89c4ff92SAndroid Build Coastguard Worker     // so that we can associate the graph file and the input/output tensor exported files
250*89c4ff92SAndroid Build Coastguard Worker     RenameExportedFiles(serializedNetworkFileName,
251*89c4ff92SAndroid Build Coastguard Worker                         dotGraphFileName,
252*89c4ff92SAndroid Build Coastguard Worker                         options.GetRequestInputsAndOutputsDumpDir(),
253*89c4ff92SAndroid Build Coastguard Worker                         netId);
254*89c4ff92SAndroid Build Coastguard Worker 
255*89c4ff92SAndroid Build Coastguard Worker     // Cache the model
256*89c4ff92SAndroid Build Coastguard Worker     size_t hashValue = 0;
257*89c4ff92SAndroid Build Coastguard Worker     if (dataCacheHandle.size() == 1 )
258*89c4ff92SAndroid Build Coastguard Worker     {
259*89c4ff92SAndroid Build Coastguard Worker         hashValue = Hash(dataCacheData);
260*89c4ff92SAndroid Build Coastguard Worker     }
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker     // Cache the model data
263*89c4ff92SAndroid Build Coastguard Worker     if (modelCacheHandle.size() > 0)
264*89c4ff92SAndroid Build Coastguard Worker     {
265*89c4ff92SAndroid Build Coastguard Worker         if (modelCacheHandle.size() == numberOfCachedModelFiles)
266*89c4ff92SAndroid Build Coastguard Worker         {
267*89c4ff92SAndroid Build Coastguard Worker             for (uint32_t i = 0; i < modelCacheHandle.size(); ++i)
268*89c4ff92SAndroid Build Coastguard Worker             {
269*89c4ff92SAndroid Build Coastguard Worker                 int modelCacheFileAccessMode = fcntl(*modelCacheHandle[i], F_GETFL) & O_ACCMODE;
270*89c4ff92SAndroid Build Coastguard Worker                 if (modelCacheFileAccessMode != O_RDONLY)
271*89c4ff92SAndroid Build Coastguard Worker                 {
272*89c4ff92SAndroid Build Coastguard Worker                     struct stat statBuffer;
273*89c4ff92SAndroid Build Coastguard Worker                     if (fstat(*modelCacheHandle[i], &statBuffer) == 0)
274*89c4ff92SAndroid Build Coastguard Worker                     {
275*89c4ff92SAndroid Build Coastguard Worker                         long modelDataSize = statBuffer.st_size;
276*89c4ff92SAndroid Build Coastguard Worker                         if (modelDataSize > 0)
277*89c4ff92SAndroid Build Coastguard Worker                         {
278*89c4ff92SAndroid Build Coastguard Worker                             std::vector<uint8_t> modelData(modelDataSize);
279*89c4ff92SAndroid Build Coastguard Worker                             pread(*modelCacheHandle[i], modelData.data(), modelData.size(), 0);
280*89c4ff92SAndroid Build Coastguard Worker                             hashValue ^= Hash(modelData);
281*89c4ff92SAndroid Build Coastguard Worker                         }
282*89c4ff92SAndroid Build Coastguard Worker                     }
283*89c4ff92SAndroid Build Coastguard Worker                 }
284*89c4ff92SAndroid Build Coastguard Worker             }
285*89c4ff92SAndroid Build Coastguard Worker         }
286*89c4ff92SAndroid Build Coastguard Worker     }
287*89c4ff92SAndroid Build Coastguard Worker     if (dataCacheHandle.size() == 1 && hashValue != 0)
288*89c4ff92SAndroid Build Coastguard Worker     {
289*89c4ff92SAndroid Build Coastguard Worker         std::vector<uint8_t> theHashValue(sizeof(hashValue));
290*89c4ff92SAndroid Build Coastguard Worker         ::memcpy(theHashValue.data(), &hashValue, sizeof(hashValue));
291*89c4ff92SAndroid Build Coastguard Worker 
292*89c4ff92SAndroid Build Coastguard Worker         write(*dataCacheHandle[0], theHashValue.data(), theHashValue.size());
293*89c4ff92SAndroid Build Coastguard Worker         pwrite(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), theHashValue.size());
294*89c4ff92SAndroid Build Coastguard Worker     }
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker     bool executeWithDummyInputs = (std::find(options.GetBackends().begin(),
297*89c4ff92SAndroid Build Coastguard Worker                                             options.GetBackends().end(),
298*89c4ff92SAndroid Build Coastguard Worker                                             armnn::Compute::GpuAcc) != options.GetBackends().end());
299*89c4ff92SAndroid Build Coastguard Worker 
300*89c4ff92SAndroid Build Coastguard Worker     auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
301*89c4ff92SAndroid Build Coastguard Worker                                                                     runtime.get(),
302*89c4ff92SAndroid Build Coastguard Worker                                                                     model,
303*89c4ff92SAndroid Build Coastguard Worker                                                                     options.GetRequestInputsAndOutputsDumpDir(),
304*89c4ff92SAndroid Build Coastguard Worker                                                                     options.IsGpuProfilingEnabled(),
305*89c4ff92SAndroid Build Coastguard Worker                                                                     priority);
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker     // Run a single 'dummy' inference of the model. This means that CL kernels will get compiled (and tuned if
308*89c4ff92SAndroid Build Coastguard Worker     // this is enabled) before the first 'real' inference which removes the overhead of the first inference.
309*89c4ff92SAndroid Build Coastguard Worker     // Only run this if the GpuAcc backend has been added to options
310*89c4ff92SAndroid Build Coastguard Worker     if (std::find(options.GetBackends().begin(),
311*89c4ff92SAndroid Build Coastguard Worker                   options.GetBackends().end(),
312*89c4ff92SAndroid Build Coastguard Worker                   armnn::Compute::GpuAcc) != options.GetBackends().end())
313*89c4ff92SAndroid Build Coastguard Worker     {
314*89c4ff92SAndroid Build Coastguard Worker         if (!preparedModel->ExecuteWithDummyInputs(numInputs, numOutputs))
315*89c4ff92SAndroid Build Coastguard Worker         {
316*89c4ff92SAndroid Build Coastguard Worker             return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Network could not be executed";
317*89c4ff92SAndroid Build Coastguard Worker         }
318*89c4ff92SAndroid Build Coastguard Worker 
319*89c4ff92SAndroid Build Coastguard Worker         if (clTunedParameters &&
320*89c4ff92SAndroid Build Coastguard Worker             options.GetClTunedParametersMode() == armnn::IGpuAccTunedParameters::Mode::UpdateTunedParameters)
321*89c4ff92SAndroid Build Coastguard Worker         {
322*89c4ff92SAndroid Build Coastguard Worker             // Now that we've done one inference the CL kernel parameters will have been tuned,
323*89c4ff92SAndroid Build Coastguard Worker             // so save the updated file.
324*89c4ff92SAndroid Build Coastguard Worker             try
325*89c4ff92SAndroid Build Coastguard Worker             {
326*89c4ff92SAndroid Build Coastguard Worker                 clTunedParameters->Save(options.GetClTunedParametersFile().c_str());
327*89c4ff92SAndroid Build Coastguard Worker             }
328*89c4ff92SAndroid Build Coastguard Worker             catch (std::exception& error)
329*89c4ff92SAndroid Build Coastguard Worker             {
330*89c4ff92SAndroid Build Coastguard Worker                 VLOG(DRIVER) << "ArmnnDriverImpl::prepareModel: Failed to save CL tuned parameters file"
331*89c4ff92SAndroid Build Coastguard Worker                              << options.GetClTunedParametersFile().c_str() << error.what();
332*89c4ff92SAndroid Build Coastguard Worker             }
333*89c4ff92SAndroid Build Coastguard Worker         }
334*89c4ff92SAndroid Build Coastguard Worker     }
335*89c4ff92SAndroid Build Coastguard Worker     return std::move(preparedModel);
336*89c4ff92SAndroid Build Coastguard Worker }
337*89c4ff92SAndroid Build Coastguard Worker 
PrepareArmnnModelFromCache(const armnn::IRuntimePtr & runtime,const armnn::IGpuAccTunedParametersPtr & clTunedParameters,const DriverOptions & options,const std::vector<SharedHandle> & modelCacheHandle,const std::vector<SharedHandle> & dataCacheHandle,const CacheToken & token,bool float32ToFloat16)338*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedPreparedModel> ArmnnDriverImpl::PrepareArmnnModelFromCache(
339*89c4ff92SAndroid Build Coastguard Worker     const armnn::IRuntimePtr& runtime,
340*89c4ff92SAndroid Build Coastguard Worker     const armnn::IGpuAccTunedParametersPtr& clTunedParameters,
341*89c4ff92SAndroid Build Coastguard Worker     const DriverOptions& options,
342*89c4ff92SAndroid Build Coastguard Worker     const std::vector<SharedHandle>& modelCacheHandle,
343*89c4ff92SAndroid Build Coastguard Worker     const std::vector<SharedHandle>& dataCacheHandle,
344*89c4ff92SAndroid Build Coastguard Worker     const CacheToken& token,
345*89c4ff92SAndroid Build Coastguard Worker     bool float32ToFloat16)
346*89c4ff92SAndroid Build Coastguard Worker {
347*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "ArmnnDriverImpl::PrepareArmnnModelFromCache()";
348*89c4ff92SAndroid Build Coastguard Worker 
349*89c4ff92SAndroid Build Coastguard Worker     if (!runtime)
350*89c4ff92SAndroid Build Coastguard Worker     {
351*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE)
352*89c4ff92SAndroid Build Coastguard Worker                             << "ArmnnDriverImpl::prepareModelFromCache(): Device unavailable";
353*89c4ff92SAndroid Build Coastguard Worker     }
354*89c4ff92SAndroid Build Coastguard Worker 
355*89c4ff92SAndroid Build Coastguard Worker     if (token.size() != ANEURALNETWORKS_BYTE_SIZE_OF_CACHE_TOKEN)
356*89c4ff92SAndroid Build Coastguard Worker     {
357*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
358*89c4ff92SAndroid Build Coastguard Worker                             << "ArmnnDriverImpl::prepareModelFromCache(): Token size does not match!";
359*89c4ff92SAndroid Build Coastguard Worker     }
360*89c4ff92SAndroid Build Coastguard Worker 
361*89c4ff92SAndroid Build Coastguard Worker     // Validate dataCacheHandle
362*89c4ff92SAndroid Build Coastguard Worker     if (dataCacheHandle.size() != 1)
363*89c4ff92SAndroid Build Coastguard Worker     {
364*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
365*89c4ff92SAndroid Build Coastguard Worker                             << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
366*89c4ff92SAndroid Build Coastguard Worker     }
367*89c4ff92SAndroid Build Coastguard Worker 
368*89c4ff92SAndroid Build Coastguard Worker     if (!ValidateSharedHandle(dataCacheHandle[0]))
369*89c4ff92SAndroid Build Coastguard Worker     {
370*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
371*89c4ff92SAndroid Build Coastguard Worker                 << "ArmnnDriverImpl::prepareModelFromCache(): Not valid data cache handle!";
372*89c4ff92SAndroid Build Coastguard Worker     }
373*89c4ff92SAndroid Build Coastguard Worker 
374*89c4ff92SAndroid Build Coastguard Worker     size_t cachedDataSize = 0;
375*89c4ff92SAndroid Build Coastguard Worker     struct stat dataStatBuffer;
376*89c4ff92SAndroid Build Coastguard Worker     if (fstat(*dataCacheHandle[0], &dataStatBuffer) == 0)
377*89c4ff92SAndroid Build Coastguard Worker     {
378*89c4ff92SAndroid Build Coastguard Worker         cachedDataSize = dataStatBuffer.st_size;
379*89c4ff92SAndroid Build Coastguard Worker     }
380*89c4ff92SAndroid Build Coastguard Worker     if (cachedDataSize == 0)
381*89c4ff92SAndroid Build Coastguard Worker     {
382*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
383*89c4ff92SAndroid Build Coastguard Worker                 << "ArmnnDriverImpl::prepareModelFromCache(): Not valid cached data!";
384*89c4ff92SAndroid Build Coastguard Worker     }
385*89c4ff92SAndroid Build Coastguard Worker 
386*89c4ff92SAndroid Build Coastguard Worker     // Check if model files cached they match the expected value
387*89c4ff92SAndroid Build Coastguard Worker     unsigned int numberOfCachedModelFiles = 0;
388*89c4ff92SAndroid Build Coastguard Worker     for (auto& backend : options.GetBackends())
389*89c4ff92SAndroid Build Coastguard Worker     {
390*89c4ff92SAndroid Build Coastguard Worker         numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
391*89c4ff92SAndroid Build Coastguard Worker     }
392*89c4ff92SAndroid Build Coastguard Worker     if (modelCacheHandle.size() != numberOfCachedModelFiles)
393*89c4ff92SAndroid Build Coastguard Worker     {
394*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
395*89c4ff92SAndroid Build Coastguard Worker                            << "ArmnnDriverImpl::prepareModelFromCache(): Model cache handle size does not match.";
396*89c4ff92SAndroid Build Coastguard Worker     }
397*89c4ff92SAndroid Build Coastguard Worker 
398*89c4ff92SAndroid Build Coastguard Worker     // Read the hashValue
399*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> hashValue(sizeof(size_t));
400*89c4ff92SAndroid Build Coastguard Worker     pread(*dataCacheHandle[0], hashValue.data(), hashValue.size(), 0);
401*89c4ff92SAndroid Build Coastguard Worker 
402*89c4ff92SAndroid Build Coastguard Worker     // Read the model
403*89c4ff92SAndroid Build Coastguard Worker     if (cachedDataSize < hashValue.size())
404*89c4ff92SAndroid Build Coastguard Worker     {
405*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
406*89c4ff92SAndroid Build Coastguard Worker                 << "ArmnnDriverImpl::prepareModelFromCache(): cachedDataSize is less than hashValue!";
407*89c4ff92SAndroid Build Coastguard Worker     }
408*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> dataCacheData(cachedDataSize - hashValue.size());
409*89c4ff92SAndroid Build Coastguard Worker     pread(*dataCacheHandle[0], dataCacheData.data(), dataCacheData.size(), hashValue.size());
410*89c4ff92SAndroid Build Coastguard Worker     auto calculatedHashValue = Hash(dataCacheData);
411*89c4ff92SAndroid Build Coastguard Worker 
412*89c4ff92SAndroid Build Coastguard Worker     int gpuAccCachedFd = -1;
413*89c4ff92SAndroid Build Coastguard Worker     if (modelCacheHandle.size() > 0)
414*89c4ff92SAndroid Build Coastguard Worker     {
415*89c4ff92SAndroid Build Coastguard Worker         unsigned int index = 0;
416*89c4ff92SAndroid Build Coastguard Worker         for (auto& backend : options.GetBackends())
417*89c4ff92SAndroid Build Coastguard Worker         {
418*89c4ff92SAndroid Build Coastguard Worker             // modelCacheHandle size should be equal to numberOfCachedModelFiles
419*89c4ff92SAndroid Build Coastguard Worker             // modelCacheHandle vector should be in same order as backends
420*89c4ff92SAndroid Build Coastguard Worker             auto numberOfCacheFiles = GetNumberOfCacheFiles(backend);
421*89c4ff92SAndroid Build Coastguard Worker             if (numberOfCacheFiles > 0)
422*89c4ff92SAndroid Build Coastguard Worker             {
423*89c4ff92SAndroid Build Coastguard Worker                 if (!ValidateSharedHandle(modelCacheHandle[index]))
424*89c4ff92SAndroid Build Coastguard Worker                 {
425*89c4ff92SAndroid Build Coastguard Worker                     return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
426*89c4ff92SAndroid Build Coastguard Worker                             << "ArmnnDriverImpl::prepareModelFromCache(): Invalid model cache handle!";
427*89c4ff92SAndroid Build Coastguard Worker                 }
428*89c4ff92SAndroid Build Coastguard Worker                 int cachedFd = *modelCacheHandle[index];
429*89c4ff92SAndroid Build Coastguard Worker                 struct stat statBuffer;
430*89c4ff92SAndroid Build Coastguard Worker                 if (fstat(cachedFd, &statBuffer) == 0)
431*89c4ff92SAndroid Build Coastguard Worker                 {
432*89c4ff92SAndroid Build Coastguard Worker                     long modelDataSize = statBuffer.st_size;
433*89c4ff92SAndroid Build Coastguard Worker                     if (modelDataSize > 0)
434*89c4ff92SAndroid Build Coastguard Worker                     {
435*89c4ff92SAndroid Build Coastguard Worker                         std::vector<uint8_t> modelData(modelDataSize);
436*89c4ff92SAndroid Build Coastguard Worker                         pread(cachedFd, modelData.data(), modelData.size(), 0);
437*89c4ff92SAndroid Build Coastguard Worker                         calculatedHashValue ^= Hash(modelData);
438*89c4ff92SAndroid Build Coastguard Worker 
439*89c4ff92SAndroid Build Coastguard Worker                         if (backend == armnn::Compute::GpuAcc)
440*89c4ff92SAndroid Build Coastguard Worker                         {
441*89c4ff92SAndroid Build Coastguard Worker                             gpuAccCachedFd = cachedFd;
442*89c4ff92SAndroid Build Coastguard Worker                         }
443*89c4ff92SAndroid Build Coastguard Worker                     }
444*89c4ff92SAndroid Build Coastguard Worker                 }
445*89c4ff92SAndroid Build Coastguard Worker                 index += numberOfCacheFiles;
446*89c4ff92SAndroid Build Coastguard Worker             }
447*89c4ff92SAndroid Build Coastguard Worker         }
448*89c4ff92SAndroid Build Coastguard Worker     }
449*89c4ff92SAndroid Build Coastguard Worker 
450*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> calculatedHashData(sizeof(calculatedHashValue));
451*89c4ff92SAndroid Build Coastguard Worker     ::memcpy(calculatedHashData.data(), &calculatedHashValue, sizeof(calculatedHashValue));
452*89c4ff92SAndroid Build Coastguard Worker     if (hashValue != calculatedHashData)
453*89c4ff92SAndroid Build Coastguard Worker     {
454*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
455*89c4ff92SAndroid Build Coastguard Worker                 << "ArmnnDriverImpl::prepareModelFromCache(): ValidateHash() failed!";
456*89c4ff92SAndroid Build Coastguard Worker     }
457*89c4ff92SAndroid Build Coastguard Worker 
458*89c4ff92SAndroid Build Coastguard Worker     // Deserialize the network..
459*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = armnn::INetworkPtr(nullptr, [](armnn::INetwork*){});
460*89c4ff92SAndroid Build Coastguard Worker     try
461*89c4ff92SAndroid Build Coastguard Worker     {
462*89c4ff92SAndroid Build Coastguard Worker         network = armnnDeserializer::IDeserializer::Create()->CreateNetworkFromBinary(dataCacheData);
463*89c4ff92SAndroid Build Coastguard Worker     }
464*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception&)
465*89c4ff92SAndroid Build Coastguard Worker     {
466*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE)
467*89c4ff92SAndroid Build Coastguard Worker                 << "ArmnnDriverImpl::prepareModelFromCache(): Exception caught from Deserializer!";
468*89c4ff92SAndroid Build Coastguard Worker     }
469*89c4ff92SAndroid Build Coastguard Worker 
470*89c4ff92SAndroid Build Coastguard Worker     // Optimize the network
471*89c4ff92SAndroid Build Coastguard Worker     armnn::IOptimizedNetworkPtr optNet(nullptr, nullptr);
472*89c4ff92SAndroid Build Coastguard Worker     armnn::OptimizerOptionsOpaque OptOptions;
473*89c4ff92SAndroid Build Coastguard Worker     OptOptions.SetReduceFp32ToFp16(float32ToFloat16);
474*89c4ff92SAndroid Build Coastguard Worker     OptOptions.SetProfilingEnabled(options.IsGpuProfilingEnabled());
475*89c4ff92SAndroid Build Coastguard Worker 
476*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions gpuAcc("GpuAcc",
477*89c4ff92SAndroid Build Coastguard Worker     {
478*89c4ff92SAndroid Build Coastguard Worker         { "FastMathEnabled", options.IsFastMathEnabled() },
479*89c4ff92SAndroid Build Coastguard Worker         { "SaveCachedNetwork", false },
480*89c4ff92SAndroid Build Coastguard Worker         { "CachedNetworkFilePath", options.GetCachedNetworkFilePath() },
481*89c4ff92SAndroid Build Coastguard Worker         { "MLGOTuningFilePath", options.GetClMLGOTunedParametersFile() },
482*89c4ff92SAndroid Build Coastguard Worker         { "CachedFileDescriptor", gpuAccCachedFd }
483*89c4ff92SAndroid Build Coastguard Worker     });
484*89c4ff92SAndroid Build Coastguard Worker 
485*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions cpuAcc("CpuAcc",
486*89c4ff92SAndroid Build Coastguard Worker     {
487*89c4ff92SAndroid Build Coastguard Worker         { "FastMathEnabled", options.IsFastMathEnabled() },
488*89c4ff92SAndroid Build Coastguard Worker         { "NumberOfThreads", options.GetNumberOfThreads() }
489*89c4ff92SAndroid Build Coastguard Worker     });
490*89c4ff92SAndroid Build Coastguard Worker     OptOptions.AddModelOption(gpuAcc);
491*89c4ff92SAndroid Build Coastguard Worker     OptOptions.AddModelOption(cpuAcc);
492*89c4ff92SAndroid Build Coastguard Worker 
493*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> errMessages;
494*89c4ff92SAndroid Build Coastguard Worker     try
495*89c4ff92SAndroid Build Coastguard Worker     {
496*89c4ff92SAndroid Build Coastguard Worker         optNet = armnn::Optimize(*network.get(),
497*89c4ff92SAndroid Build Coastguard Worker                                  options.GetBackends(),
498*89c4ff92SAndroid Build Coastguard Worker                                  runtime->GetDeviceSpec(),
499*89c4ff92SAndroid Build Coastguard Worker                                  OptOptions,
500*89c4ff92SAndroid Build Coastguard Worker                                  errMessages);
501*89c4ff92SAndroid Build Coastguard Worker     }
502*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception& e)
503*89c4ff92SAndroid Build Coastguard Worker     {
504*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << e.what();
505*89c4ff92SAndroid Build Coastguard Worker     }
506*89c4ff92SAndroid Build Coastguard Worker 
507*89c4ff92SAndroid Build Coastguard Worker     // Check that the optimized network is valid.
508*89c4ff92SAndroid Build Coastguard Worker     if (!optNet)
509*89c4ff92SAndroid Build Coastguard Worker     {
510*89c4ff92SAndroid Build Coastguard Worker         std::stringstream message;
511*89c4ff92SAndroid Build Coastguard Worker         message << "Invalid optimized network";
512*89c4ff92SAndroid Build Coastguard Worker         for (const std::string& msg : errMessages)
513*89c4ff92SAndroid Build Coastguard Worker         {
514*89c4ff92SAndroid Build Coastguard Worker             message << "\n" << msg;
515*89c4ff92SAndroid Build Coastguard Worker         }
516*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
517*89c4ff92SAndroid Build Coastguard Worker     }
518*89c4ff92SAndroid Build Coastguard Worker 
519*89c4ff92SAndroid Build Coastguard Worker     // Export the optimized network graph to a dot file if an output dump directory
520*89c4ff92SAndroid Build Coastguard Worker     // has been specified in the drivers' arguments.
521*89c4ff92SAndroid Build Coastguard Worker     std::string dotGraphFileName = ExportNetworkGraphToDotFile(*optNet,
522*89c4ff92SAndroid Build Coastguard Worker                                                                options.GetRequestInputsAndOutputsDumpDir());
523*89c4ff92SAndroid Build Coastguard Worker 
524*89c4ff92SAndroid Build Coastguard Worker     // Load it into the runtime.
525*89c4ff92SAndroid Build Coastguard Worker     armnn::NetworkId netId = 0;
526*89c4ff92SAndroid Build Coastguard Worker     std::string msg;
527*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkProperties networkProperties(options.isAsyncModelExecutionEnabled(),
528*89c4ff92SAndroid Build Coastguard Worker                                                 MemorySource::Undefined,
529*89c4ff92SAndroid Build Coastguard Worker                                                 MemorySource::Undefined,
530*89c4ff92SAndroid Build Coastguard Worker                                                 options.IsGpuProfilingEnabled());
531*89c4ff92SAndroid Build Coastguard Worker     try
532*89c4ff92SAndroid Build Coastguard Worker     {
533*89c4ff92SAndroid Build Coastguard Worker         if (runtime->LoadNetwork(netId, move(optNet), msg, networkProperties) != armnn::Status::Success)
534*89c4ff92SAndroid Build Coastguard Worker         {
535*89c4ff92SAndroid Build Coastguard Worker             return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Network could not be loaded";
536*89c4ff92SAndroid Build Coastguard Worker         }
537*89c4ff92SAndroid Build Coastguard Worker     }
538*89c4ff92SAndroid Build Coastguard Worker     catch (std::exception& e)
539*89c4ff92SAndroid Build Coastguard Worker     {
540*89c4ff92SAndroid Build Coastguard Worker         std::stringstream message;
541*89c4ff92SAndroid Build Coastguard Worker         message << "Exception (" << e.what()<< ") caught from LoadNetwork.";
542*89c4ff92SAndroid Build Coastguard Worker         return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << message.str();
543*89c4ff92SAndroid Build Coastguard Worker     }
544*89c4ff92SAndroid Build Coastguard Worker 
545*89c4ff92SAndroid Build Coastguard Worker     auto preparedModel = std::make_shared<const ArmnnPreparedModel>(netId,
546*89c4ff92SAndroid Build Coastguard Worker                                                       runtime.get(),
547*89c4ff92SAndroid Build Coastguard Worker                                                       options.GetRequestInputsAndOutputsDumpDir(),
548*89c4ff92SAndroid Build Coastguard Worker                                                       options.IsGpuProfilingEnabled(),
549*89c4ff92SAndroid Build Coastguard Worker                                                       Priority::MEDIUM,
550*89c4ff92SAndroid Build Coastguard Worker                                                       true);
551*89c4ff92SAndroid Build Coastguard Worker     return std::move(preparedModel);
552*89c4ff92SAndroid Build Coastguard Worker }
553*89c4ff92SAndroid Build Coastguard Worker 
GetCapabilities(const armnn::IRuntimePtr & runtime)554*89c4ff92SAndroid Build Coastguard Worker const Capabilities& ArmnnDriverImpl::GetCapabilities(const armnn::IRuntimePtr& runtime)
555*89c4ff92SAndroid Build Coastguard Worker {
556*89c4ff92SAndroid Build Coastguard Worker     VLOG(DRIVER) << "ArmnnDriverImpl::GetCapabilities()";
557*89c4ff92SAndroid Build Coastguard Worker     static const Capabilities theCapabilities = GenerateCapabilities();
558*89c4ff92SAndroid Build Coastguard Worker     return theCapabilities;
559*89c4ff92SAndroid Build Coastguard Worker }
560*89c4ff92SAndroid Build Coastguard Worker 
561*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
562