1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #define LOG_TAG "arm-armnn-sl"
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "ArmnnPreparedModel.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "CanonicalUtils.hpp"
10*89c4ff92SAndroid Build Coastguard Worker
11*89c4ff92SAndroid Build Coastguard Worker #include <DefaultExecution.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <LegacyUtils.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IBurst.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/IPreparedModel.h>
15*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Result.h>
16*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/SharedMemory.h>
17*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/TypeUtils.h>
18*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Types.h>
19*89c4ff92SAndroid Build Coastguard Worker #include <nnapi/Validation.h>
20*89c4ff92SAndroid Build Coastguard Worker
21*89c4ff92SAndroid Build Coastguard Worker #include <memory>
22*89c4ff92SAndroid Build Coastguard Worker #include <tuple>
23*89c4ff92SAndroid Build Coastguard Worker #include <utility>
24*89c4ff92SAndroid Build Coastguard Worker #include <vector>
25*89c4ff92SAndroid Build Coastguard Worker
26*89c4ff92SAndroid Build Coastguard Worker using namespace android;
27*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker static const Timing g_NoTiming = {};
30*89c4ff92SAndroid Build Coastguard Worker
31*89c4ff92SAndroid Build Coastguard Worker namespace {
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker using namespace armnn_driver;
34*89c4ff92SAndroid Build Coastguard Worker
MicrosecondsDuration(android::nn::TimePoint endPoint,android::nn::TimePoint startPoint)35*89c4ff92SAndroid Build Coastguard Worker unsigned long MicrosecondsDuration(android::nn::TimePoint endPoint, android::nn::TimePoint startPoint)
36*89c4ff92SAndroid Build Coastguard Worker {
37*89c4ff92SAndroid Build Coastguard Worker return static_cast<unsigned long>(std::chrono::duration_cast<std::chrono::microseconds>(
38*89c4ff92SAndroid Build Coastguard Worker endPoint - startPoint).count());
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
ValidateRequestArgument(const Request::Argument & requestArg,const armnn::TensorInfo & tensorInfo)41*89c4ff92SAndroid Build Coastguard Worker bool ValidateRequestArgument(const Request::Argument& requestArg, const armnn::TensorInfo& tensorInfo)
42*89c4ff92SAndroid Build Coastguard Worker {
43*89c4ff92SAndroid Build Coastguard Worker if (requestArg.dimensions.size() != 0)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker if (requestArg.dimensions.size() != tensorInfo.GetNumDimensions())
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Mismatched dimensions (request argument: "
48*89c4ff92SAndroid Build Coastguard Worker << requestArg.dimensions.size() << " expected: " << tensorInfo.GetNumDimensions();
49*89c4ff92SAndroid Build Coastguard Worker return false;
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker for (unsigned int d = 0; d < tensorInfo.GetNumDimensions(); ++d)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker if (requestArg.dimensions[d] != 0 && requestArg.dimensions[d] != tensorInfo.GetShape()[d])
55*89c4ff92SAndroid Build Coastguard Worker {
56*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Mismatched dimensions " << d
57*89c4ff92SAndroid Build Coastguard Worker << " (request argument: " << requestArg.dimensions[d]
58*89c4ff92SAndroid Build Coastguard Worker << " expected: " << tensorInfo.GetShape()[d];
59*89c4ff92SAndroid Build Coastguard Worker return false;
60*89c4ff92SAndroid Build Coastguard Worker }
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker
64*89c4ff92SAndroid Build Coastguard Worker return true;
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker
GetTensorForRequestArgument(const Request::Argument & requestArg,const armnn::TensorInfo & tensorInfo,const std::vector<::android::nn::RunTimePoolInfo> & requestPools)67*89c4ff92SAndroid Build Coastguard Worker armnn::Tensor GetTensorForRequestArgument(const Request::Argument& requestArg,
68*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& tensorInfo,
69*89c4ff92SAndroid Build Coastguard Worker const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker if (!ValidateRequestArgument(requestArg, tensorInfo))
72*89c4ff92SAndroid Build Coastguard Worker {
73*89c4ff92SAndroid Build Coastguard Worker return armnn::Tensor();
74*89c4ff92SAndroid Build Coastguard Worker }
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker if (requestArg.lifetime == Request::Argument::LifeTime::POINTER)
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker return armnn::Tensor(tensorInfo, GetMemoryFromPointer(requestArg));
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker else if (requestArg.lifetime == Request::Argument::LifeTime::POOL)
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker return armnn::Tensor(tensorInfo, GetMemoryFromPool(requestArg.location, requestPools));
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker return armnn::Tensor();
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker
/// Appends the decimal form of @p index to @p tensorNamePrefix, e.g. ("Input", 2) -> "Input2".
inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t index)
{
    std::string name(tensorNamePrefix);
    name.append(std::to_string(index));
    return name;
}
91*89c4ff92SAndroid Build Coastguard Worker
IsPointerTypeMemory(const Request & request)92*89c4ff92SAndroid Build Coastguard Worker bool IsPointerTypeMemory(const Request& request)
93*89c4ff92SAndroid Build Coastguard Worker {
94*89c4ff92SAndroid Build Coastguard Worker for (auto& input : request.inputs)
95*89c4ff92SAndroid Build Coastguard Worker {
96*89c4ff92SAndroid Build Coastguard Worker if (input.lifetime != Request::Argument::LifeTime::POINTER)
97*89c4ff92SAndroid Build Coastguard Worker {
98*89c4ff92SAndroid Build Coastguard Worker return false;
99*89c4ff92SAndroid Build Coastguard Worker }
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker
102*89c4ff92SAndroid Build Coastguard Worker for (auto& output: request.outputs)
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker if (output.lifetime != Request::Argument::LifeTime::POINTER)
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker return false;
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker
110*89c4ff92SAndroid Build Coastguard Worker return true;
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker
113*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
114*89c4ff92SAndroid Build Coastguard Worker
115*89c4ff92SAndroid Build Coastguard Worker using namespace android::nn;
116*89c4ff92SAndroid Build Coastguard Worker
117*89c4ff92SAndroid Build Coastguard Worker namespace armnn_driver
118*89c4ff92SAndroid Build Coastguard Worker {
119*89c4ff92SAndroid Build Coastguard Worker
Init()120*89c4ff92SAndroid Build Coastguard Worker void ArmnnPreparedModel::Init()
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker // Enable profiling if required.
123*89c4ff92SAndroid Build Coastguard Worker m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker
ArmnnPreparedModel(armnn::NetworkId networkId,armnn::IRuntime * runtime,const Model & model,const std::string & requestInputsAndOutputsDumpDir,const bool gpuProfilingEnabled,Priority priority)126*89c4ff92SAndroid Build Coastguard Worker ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId,
127*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* runtime,
128*89c4ff92SAndroid Build Coastguard Worker const Model& model,
129*89c4ff92SAndroid Build Coastguard Worker const std::string& requestInputsAndOutputsDumpDir,
130*89c4ff92SAndroid Build Coastguard Worker const bool gpuProfilingEnabled,
131*89c4ff92SAndroid Build Coastguard Worker Priority priority)
132*89c4ff92SAndroid Build Coastguard Worker : m_NetworkId(networkId)
133*89c4ff92SAndroid Build Coastguard Worker , m_Runtime(runtime)
134*89c4ff92SAndroid Build Coastguard Worker , m_Model(model)
135*89c4ff92SAndroid Build Coastguard Worker , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
136*89c4ff92SAndroid Build Coastguard Worker , m_GpuProfilingEnabled(gpuProfilingEnabled)
137*89c4ff92SAndroid Build Coastguard Worker , m_ModelPriority(priority)
138*89c4ff92SAndroid Build Coastguard Worker , m_PrepareFromCache(false)
139*89c4ff92SAndroid Build Coastguard Worker {
140*89c4ff92SAndroid Build Coastguard Worker Init();
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker
ArmnnPreparedModel(armnn::NetworkId networkId,armnn::IRuntime * runtime,const std::string & requestInputsAndOutputsDumpDir,const bool gpuProfilingEnabled,Priority priority,const bool prepareModelFromCache)143*89c4ff92SAndroid Build Coastguard Worker ArmnnPreparedModel::ArmnnPreparedModel(armnn::NetworkId networkId,
144*89c4ff92SAndroid Build Coastguard Worker armnn::IRuntime* runtime,
145*89c4ff92SAndroid Build Coastguard Worker const std::string& requestInputsAndOutputsDumpDir,
146*89c4ff92SAndroid Build Coastguard Worker const bool gpuProfilingEnabled,
147*89c4ff92SAndroid Build Coastguard Worker Priority priority,
148*89c4ff92SAndroid Build Coastguard Worker const bool prepareModelFromCache)
149*89c4ff92SAndroid Build Coastguard Worker : m_NetworkId(networkId)
150*89c4ff92SAndroid Build Coastguard Worker , m_Runtime(runtime)
151*89c4ff92SAndroid Build Coastguard Worker , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
152*89c4ff92SAndroid Build Coastguard Worker , m_GpuProfilingEnabled(gpuProfilingEnabled)
153*89c4ff92SAndroid Build Coastguard Worker , m_ModelPriority(priority)
154*89c4ff92SAndroid Build Coastguard Worker , m_PrepareFromCache(prepareModelFromCache)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker Init();
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker
159*89c4ff92SAndroid Build Coastguard Worker
PrepareMemoryForInputs(armnn::InputTensors & inputs,const Request & request,const std::vector<android::nn::RunTimePoolInfo> & memPools) const160*89c4ff92SAndroid Build Coastguard Worker ErrorStatus ArmnnPreparedModel::PrepareMemoryForInputs(
161*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors& inputs,
162*89c4ff92SAndroid Build Coastguard Worker const Request& request,
163*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::RunTimePoolInfo>& memPools) const
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker inputs.reserve(request.inputs.size());
166*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < request.inputs.size(); i++)
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker const auto& inputArg = request.inputs[i];
169*89c4ff92SAndroid Build Coastguard Worker
170*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
171*89c4ff92SAndroid Build Coastguard Worker // inputs (of type InputTensors) is composed of a vector of ConstTensors.
172*89c4ff92SAndroid Build Coastguard Worker // Therefore, set all TensorInfo isConstant parameters of input Tensors to true.
173*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant();
174*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools);
175*89c4ff92SAndroid Build Coastguard Worker
176*89c4ff92SAndroid Build Coastguard Worker if (inputTensor.GetMemoryArea() == nullptr)
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Cannot execute request. Error converting request input " << i << "to tensor.";
179*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
180*89c4ff92SAndroid Build Coastguard Worker }
181*89c4ff92SAndroid Build Coastguard Worker inputs.emplace_back(i, inputTensor);
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker
184*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::NONE;
185*89c4ff92SAndroid Build Coastguard Worker }
186*89c4ff92SAndroid Build Coastguard Worker
PrepareMemoryForOutputs(armnn::OutputTensors & outputs,std::vector<OutputShape> & outputShapes,const Request & request,const std::vector<android::nn::RunTimePoolInfo> & memPools) const187*89c4ff92SAndroid Build Coastguard Worker ErrorStatus ArmnnPreparedModel::PrepareMemoryForOutputs(
188*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors& outputs,
189*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputShape> &outputShapes,
190*89c4ff92SAndroid Build Coastguard Worker const Request& request,
191*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::RunTimePoolInfo>& memPools) const
192*89c4ff92SAndroid Build Coastguard Worker {
193*89c4ff92SAndroid Build Coastguard Worker outputs.reserve(request.outputs.size());
194*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < request.outputs.size(); i++)
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker auto& outputArg = request.outputs[i];
197*89c4ff92SAndroid Build Coastguard Worker
198*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
199*89c4ff92SAndroid Build Coastguard Worker armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools);
200*89c4ff92SAndroid Build Coastguard Worker if (outputTensor.GetMemoryArea() == nullptr)
201*89c4ff92SAndroid Build Coastguard Worker {
202*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Cannot execute request. Error converting request output " << i << "to tensor.";
203*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker
206*89c4ff92SAndroid Build Coastguard Worker const size_t outputSize = outputTensorInfo.GetNumBytes();
207*89c4ff92SAndroid Build Coastguard Worker
208*89c4ff92SAndroid Build Coastguard Worker unsigned int count = 0;
209*89c4ff92SAndroid Build Coastguard Worker std::for_each(outputArg.dimensions.begin(), outputArg.dimensions.end(), [&](auto dim)
210*89c4ff92SAndroid Build Coastguard Worker {
211*89c4ff92SAndroid Build Coastguard Worker if (dim != 0)
212*89c4ff92SAndroid Build Coastguard Worker {
213*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.GetShape()[count] = dim;
214*89c4ff92SAndroid Build Coastguard Worker }
215*89c4ff92SAndroid Build Coastguard Worker else
216*89c4ff92SAndroid Build Coastguard Worker {
217*89c4ff92SAndroid Build Coastguard Worker outputTensorInfo.GetShape()[count] = outputArg.dimensions.size();
218*89c4ff92SAndroid Build Coastguard Worker }
219*89c4ff92SAndroid Build Coastguard Worker
220*89c4ff92SAndroid Build Coastguard Worker count++;
221*89c4ff92SAndroid Build Coastguard Worker });
222*89c4ff92SAndroid Build Coastguard Worker
223*89c4ff92SAndroid Build Coastguard Worker outputs.emplace_back(i, outputTensor);
224*89c4ff92SAndroid Build Coastguard Worker outputShapes[i] = ComputeShape(outputTensorInfo);
225*89c4ff92SAndroid Build Coastguard Worker
226*89c4ff92SAndroid Build Coastguard Worker if (outputArg.location.length < outputSize)
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::Execute failed outputArg.location.length "
229*89c4ff92SAndroid Build Coastguard Worker << std::to_string(outputArg.location.length).c_str()
230*89c4ff92SAndroid Build Coastguard Worker << " < outputSize " << std::to_string(outputSize).c_str();
231*89c4ff92SAndroid Build Coastguard Worker outputShapes[i].isSufficient = false;
232*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
233*89c4ff92SAndroid Build Coastguard Worker }
234*89c4ff92SAndroid Build Coastguard Worker
235*89c4ff92SAndroid Build Coastguard Worker //TODO: Need to check for Request::Argument::LifeTime::POINTER
236*89c4ff92SAndroid Build Coastguard Worker if (outputArg.lifetime == Request::Argument::LifeTime::POOL)
237*89c4ff92SAndroid Build Coastguard Worker {
238*89c4ff92SAndroid Build Coastguard Worker size_t bufferSize = memPools.at(outputArg.location.poolIndex).getSize();
239*89c4ff92SAndroid Build Coastguard Worker if (bufferSize < outputSize)
240*89c4ff92SAndroid Build Coastguard Worker {
241*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::Execute failed bufferSize "
242*89c4ff92SAndroid Build Coastguard Worker << std::to_string(outputArg.location.length).c_str()
243*89c4ff92SAndroid Build Coastguard Worker << " < outputSize " << std::to_string(outputSize).c_str();
244*89c4ff92SAndroid Build Coastguard Worker outputShapes[i].isSufficient = false;
245*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
246*89c4ff92SAndroid Build Coastguard Worker }
247*89c4ff92SAndroid Build Coastguard Worker }
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::NONE;
250*89c4ff92SAndroid Build Coastguard Worker }
251*89c4ff92SAndroid Build Coastguard Worker
PrepareMemoryForIO(armnn::InputTensors & inputs,armnn::OutputTensors & outputs,std::vector<android::nn::RunTimePoolInfo> & memPools,const Request & request,const bool pointerMemory) const252*89c4ff92SAndroid Build Coastguard Worker ErrorStatus ArmnnPreparedModel::PrepareMemoryForIO(armnn::InputTensors& inputs,
253*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors& outputs,
254*89c4ff92SAndroid Build Coastguard Worker std::vector<android::nn::RunTimePoolInfo>& memPools,
255*89c4ff92SAndroid Build Coastguard Worker const Request& request,
256*89c4ff92SAndroid Build Coastguard Worker const bool pointerMemory) const
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker //Check memory pools are not empty
259*89c4ff92SAndroid Build Coastguard Worker // add the inputs and outputs with their data
260*89c4ff92SAndroid Build Coastguard Worker try
261*89c4ff92SAndroid Build Coastguard Worker {
262*89c4ff92SAndroid Build Coastguard Worker if (!pointerMemory && !setRunTimePoolInfosFromMemoryPools(&memPools, request.pools))
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::INVALID_ARGUMENT;
265*89c4ff92SAndroid Build Coastguard Worker }
266*89c4ff92SAndroid Build Coastguard Worker
267*89c4ff92SAndroid Build Coastguard Worker if (PrepareMemoryForInputs(inputs, request, memPools) != ErrorStatus::NONE)
268*89c4ff92SAndroid Build Coastguard Worker {
269*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "Failed when preparing memory for Inputs";
270*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
271*89c4ff92SAndroid Build Coastguard Worker }
272*89c4ff92SAndroid Build Coastguard Worker
273*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputShape> outputShapes(request.outputs.size());
274*89c4ff92SAndroid Build Coastguard Worker
275*89c4ff92SAndroid Build Coastguard Worker auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
276*89c4ff92SAndroid Build Coastguard Worker if (errorStatus != ErrorStatus::NONE)
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker return errorStatus;
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker }
281*89c4ff92SAndroid Build Coastguard Worker catch (armnn::Exception& e)
282*89c4ff92SAndroid Build Coastguard Worker {
283*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "armnn::Exception caught while preparing for EnqueueWorkload: " << e.what();
284*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
285*89c4ff92SAndroid Build Coastguard Worker }
286*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& e)
287*89c4ff92SAndroid Build Coastguard Worker {
288*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "std::exception caught while preparing for EnqueueWorkload: " << e.what();
289*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
290*89c4ff92SAndroid Build Coastguard Worker }
291*89c4ff92SAndroid Build Coastguard Worker
292*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::NONE;
293*89c4ff92SAndroid Build Coastguard Worker }
294*89c4ff92SAndroid Build Coastguard Worker
execute(const Request & request,MeasureTiming measureTiming,const OptionalTimePoint & deadline,const OptionalDuration &,const std::vector<android::nn::TokenValuePair> & hints,const std::vector<android::nn::ExtensionNameAndPrefix> & extensionNameToPrefix) const295*89c4ff92SAndroid Build Coastguard Worker ExecutionResult<std::pair<std::vector<OutputShape>, Timing>> ArmnnPreparedModel::execute(
296*89c4ff92SAndroid Build Coastguard Worker const Request& request,
297*89c4ff92SAndroid Build Coastguard Worker MeasureTiming measureTiming,
298*89c4ff92SAndroid Build Coastguard Worker const OptionalTimePoint& deadline,
299*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration&,
300*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints,
301*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const
302*89c4ff92SAndroid Build Coastguard Worker {
303*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "CanonicalDriver::PreparedModel::execute()";
304*89c4ff92SAndroid Build Coastguard Worker
305*89c4ff92SAndroid Build Coastguard Worker CanonicalExecutionContext ctx;
306*89c4ff92SAndroid Build Coastguard Worker if (measureTiming == MeasureTiming::YES)
307*89c4ff92SAndroid Build Coastguard Worker {
308*89c4ff92SAndroid Build Coastguard Worker ctx.measureTimings = measureTiming;
309*89c4ff92SAndroid Build Coastguard Worker ctx.driverStart = Clock::now();
310*89c4ff92SAndroid Build Coastguard Worker }
311*89c4ff92SAndroid Build Coastguard Worker
312*89c4ff92SAndroid Build Coastguard Worker if (!m_PrepareFromCache)
313*89c4ff92SAndroid Build Coastguard Worker {
314*89c4ff92SAndroid Build Coastguard Worker const auto modelRequest = validateRequestForModel(request, m_Model);
315*89c4ff92SAndroid Build Coastguard Worker if (!modelRequest.ok())
316*89c4ff92SAndroid Build Coastguard Worker {
317*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << modelRequest.error();
318*89c4ff92SAndroid Build Coastguard Worker }
319*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::execute(): " << GetModelSummary(m_Model).c_str();
320*89c4ff92SAndroid Build Coastguard Worker }
321*89c4ff92SAndroid Build Coastguard Worker if (hasDeadlinePassed(deadline))
322*89c4ff92SAndroid Build Coastguard Worker {
323*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
324*89c4ff92SAndroid Build Coastguard Worker }
325*89c4ff92SAndroid Build Coastguard Worker
326*89c4ff92SAndroid Build Coastguard Worker // map the memory pool into shared pointers
327*89c4ff92SAndroid Build Coastguard Worker // use a shared memory pools vector on the heap, as it is passed to the request thread
328*89c4ff92SAndroid Build Coastguard Worker auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
329*89c4ff92SAndroid Build Coastguard Worker
330*89c4ff92SAndroid Build Coastguard Worker // allocate the tensors on the heap, as they are passed to the request thread
331*89c4ff92SAndroid Build Coastguard Worker auto inputTensors = std::make_shared<armnn::InputTensors>();
332*89c4ff92SAndroid Build Coastguard Worker auto outputTensors = std::make_shared<armnn::OutputTensors>();
333*89c4ff92SAndroid Build Coastguard Worker
334*89c4ff92SAndroid Build Coastguard Worker auto isPointerTypeMemory = IsPointerTypeMemory(request);
335*89c4ff92SAndroid Build Coastguard Worker ErrorStatus theErrorStatus = PrepareMemoryForIO(*inputTensors,
336*89c4ff92SAndroid Build Coastguard Worker *outputTensors,
337*89c4ff92SAndroid Build Coastguard Worker *memPools,
338*89c4ff92SAndroid Build Coastguard Worker request,
339*89c4ff92SAndroid Build Coastguard Worker isPointerTypeMemory);
340*89c4ff92SAndroid Build Coastguard Worker
341*89c4ff92SAndroid Build Coastguard Worker switch(theErrorStatus)
342*89c4ff92SAndroid Build Coastguard Worker {
343*89c4ff92SAndroid Build Coastguard Worker case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
344*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::OUTPUT_INSUFFICIENT_SIZE);
345*89c4ff92SAndroid Build Coastguard Worker case ErrorStatus::GENERAL_FAILURE:
346*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE);
347*89c4ff92SAndroid Build Coastguard Worker case ErrorStatus::INVALID_ARGUMENT:
348*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
349*89c4ff92SAndroid Build Coastguard Worker default:
350*89c4ff92SAndroid Build Coastguard Worker {}
351*89c4ff92SAndroid Build Coastguard Worker }
352*89c4ff92SAndroid Build Coastguard Worker
353*89c4ff92SAndroid Build Coastguard Worker std::vector<OutputShape> outputShapes(outputTensors->size());
354*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < outputTensors->size(); i++)
355*89c4ff92SAndroid Build Coastguard Worker {
356*89c4ff92SAndroid Build Coastguard Worker std::pair<int, armnn::Tensor> outputTensorPair = (*outputTensors)[i];
357*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor outputTensor = outputTensorPair.second;
358*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
359*89c4ff92SAndroid Build Coastguard Worker
360*89c4ff92SAndroid Build Coastguard Worker outputShapes[i] = ComputeShape(outputTensorInfo);
361*89c4ff92SAndroid Build Coastguard Worker }
362*89c4ff92SAndroid Build Coastguard Worker Timing theTiming;
363*89c4ff92SAndroid Build Coastguard Worker
364*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::execute(...) before ExecuteGraph";
365*89c4ff92SAndroid Build Coastguard Worker auto errorStatus = ExecuteGraph(memPools, *inputTensors, *outputTensors, ctx, isPointerTypeMemory);
366*89c4ff92SAndroid Build Coastguard Worker if (errorStatus != ErrorStatus::NONE)
367*89c4ff92SAndroid Build Coastguard Worker {
368*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(errorStatus) << "execute() failed";
369*89c4ff92SAndroid Build Coastguard Worker }
370*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::execute(...) after ExecuteGraph";
371*89c4ff92SAndroid Build Coastguard Worker
372*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(outputShapes, theTiming);
373*89c4ff92SAndroid Build Coastguard Worker }
374*89c4ff92SAndroid Build Coastguard Worker
ExecuteGraph(std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>> & pMemPools,armnn::InputTensors & inputTensors,armnn::OutputTensors & outputTensors,CanonicalExecutionContext ctx,const bool pointerMemory) const375*89c4ff92SAndroid Build Coastguard Worker ErrorStatus ArmnnPreparedModel::ExecuteGraph(
376*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<std::vector<android::nn::RunTimePoolInfo>>& pMemPools,
377*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors& inputTensors,
378*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors& outputTensors,
379*89c4ff92SAndroid Build Coastguard Worker CanonicalExecutionContext ctx,
380*89c4ff92SAndroid Build Coastguard Worker const bool pointerMemory) const
381*89c4ff92SAndroid Build Coastguard Worker {
382*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph(...)";
383*89c4ff92SAndroid Build Coastguard Worker
384*89c4ff92SAndroid Build Coastguard Worker DumpTensorsIfRequired("Input", inputTensors);
385*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::ImportedInputId> importedInputIds;
386*89c4ff92SAndroid Build Coastguard Worker std::vector<armnn::ImportedOutputId> importedOutputIds;
387*89c4ff92SAndroid Build Coastguard Worker try
388*89c4ff92SAndroid Build Coastguard Worker {
389*89c4ff92SAndroid Build Coastguard Worker if (ctx.measureTimings == MeasureTiming::YES)
390*89c4ff92SAndroid Build Coastguard Worker {
391*89c4ff92SAndroid Build Coastguard Worker ctx.deviceStart = Clock::now();
392*89c4ff92SAndroid Build Coastguard Worker }
393*89c4ff92SAndroid Build Coastguard Worker armnn::Status status;
394*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false";
395*89c4ff92SAndroid Build Coastguard Worker importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
396*89c4ff92SAndroid Build Coastguard Worker if (!importedInputIds.empty())
397*89c4ff92SAndroid Build Coastguard Worker {
398*89c4ff92SAndroid Build Coastguard Worker // Some or all of the input tensors been imported. We need to remove the ones that could from
399*89c4ff92SAndroid Build Coastguard Worker // inputTensors.
400*89c4ff92SAndroid Build Coastguard Worker for (armnn::ImportedInputId& importedId : importedInputIds)
401*89c4ff92SAndroid Build Coastguard Worker {
402*89c4ff92SAndroid Build Coastguard Worker inputTensors.erase(
403*89c4ff92SAndroid Build Coastguard Worker std::remove_if(
404*89c4ff92SAndroid Build Coastguard Worker inputTensors.begin(), inputTensors.end(),
405*89c4ff92SAndroid Build Coastguard Worker [&importedId](std::pair<armnn::LayerBindingId, class armnn::ConstTensor>& element) {
406*89c4ff92SAndroid Build Coastguard Worker return (element.first == static_cast<int>(importedId));
407*89c4ff92SAndroid Build Coastguard Worker }),
408*89c4ff92SAndroid Build Coastguard Worker inputTensors.end());
409*89c4ff92SAndroid Build Coastguard Worker }
410*89c4ff92SAndroid Build Coastguard Worker }
411*89c4ff92SAndroid Build Coastguard Worker importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
412*89c4ff92SAndroid Build Coastguard Worker if (!importedOutputIds.empty())
413*89c4ff92SAndroid Build Coastguard Worker {
414*89c4ff92SAndroid Build Coastguard Worker // Some or all of the output tensors could not be imported. We need to remove the ones that could
415*89c4ff92SAndroid Build Coastguard Worker // from outputTensors.
416*89c4ff92SAndroid Build Coastguard Worker for (armnn::ImportedInputId& importedId : importedOutputIds)
417*89c4ff92SAndroid Build Coastguard Worker {
418*89c4ff92SAndroid Build Coastguard Worker outputTensors.erase(
419*89c4ff92SAndroid Build Coastguard Worker std::remove_if(
420*89c4ff92SAndroid Build Coastguard Worker outputTensors.begin(), outputTensors.end(),
421*89c4ff92SAndroid Build Coastguard Worker [&importedId](std::pair<armnn::LayerBindingId, class armnn::Tensor>& element) {
422*89c4ff92SAndroid Build Coastguard Worker return (element.first == static_cast<int>(importedId));
423*89c4ff92SAndroid Build Coastguard Worker }),
424*89c4ff92SAndroid Build Coastguard Worker outputTensors.end());
425*89c4ff92SAndroid Build Coastguard Worker }
426*89c4ff92SAndroid Build Coastguard Worker }
427*89c4ff92SAndroid Build Coastguard Worker status = m_Runtime->EnqueueWorkload(m_NetworkId,
428*89c4ff92SAndroid Build Coastguard Worker inputTensors,
429*89c4ff92SAndroid Build Coastguard Worker outputTensors,
430*89c4ff92SAndroid Build Coastguard Worker importedInputIds,
431*89c4ff92SAndroid Build Coastguard Worker importedOutputIds);
432*89c4ff92SAndroid Build Coastguard Worker
433*89c4ff92SAndroid Build Coastguard Worker if (ctx.measureTimings == MeasureTiming::YES)
434*89c4ff92SAndroid Build Coastguard Worker {
435*89c4ff92SAndroid Build Coastguard Worker ctx.deviceEnd = Clock::now();
436*89c4ff92SAndroid Build Coastguard Worker }
437*89c4ff92SAndroid Build Coastguard Worker if (status != armnn::Status::Success)
438*89c4ff92SAndroid Build Coastguard Worker {
439*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel:ExecuteGraph EnqueueWorkload failed";
440*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
441*89c4ff92SAndroid Build Coastguard Worker }
442*89c4ff92SAndroid Build Coastguard Worker }
443*89c4ff92SAndroid Build Coastguard Worker catch (armnn::Exception& e)
444*89c4ff92SAndroid Build Coastguard Worker {
445*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "armnn:Exception caught from EnqueueWorkload: " << e.what();
446*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
447*89c4ff92SAndroid Build Coastguard Worker }
448*89c4ff92SAndroid Build Coastguard Worker catch (std::exception& e)
449*89c4ff92SAndroid Build Coastguard Worker {
450*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "std::exception caught from EnqueueWorkload: " << e.what();
451*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::GENERAL_FAILURE;
452*89c4ff92SAndroid Build Coastguard Worker }
453*89c4ff92SAndroid Build Coastguard Worker
454*89c4ff92SAndroid Build Coastguard Worker if (!pointerMemory && (!importedInputIds.empty() || !importedOutputIds.empty()))
455*89c4ff92SAndroid Build Coastguard Worker {
456*89c4ff92SAndroid Build Coastguard Worker CommitPools(*pMemPools);
457*89c4ff92SAndroid Build Coastguard Worker }
458*89c4ff92SAndroid Build Coastguard Worker DumpTensorsIfRequired("Output", outputTensors);
459*89c4ff92SAndroid Build Coastguard Worker
460*89c4ff92SAndroid Build Coastguard Worker if (ctx.measureTimings == MeasureTiming::YES)
461*89c4ff92SAndroid Build Coastguard Worker {
462*89c4ff92SAndroid Build Coastguard Worker ctx.driverEnd = Clock::now();
463*89c4ff92SAndroid Build Coastguard Worker Timing timing;
464*89c4ff92SAndroid Build Coastguard Worker timing.timeOnDevice = ctx.deviceEnd - ctx.deviceStart;
465*89c4ff92SAndroid Build Coastguard Worker timing.timeInDriver = ctx.driverEnd - ctx.driverStart;
466*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::execute timing - Device = "
467*89c4ff92SAndroid Build Coastguard Worker << timing.timeOnDevice << "Driver = " << timing.timeInDriver;
468*89c4ff92SAndroid Build Coastguard Worker }
469*89c4ff92SAndroid Build Coastguard Worker return ErrorStatus::NONE;
470*89c4ff92SAndroid Build Coastguard Worker }
471*89c4ff92SAndroid Build Coastguard Worker
GetModelPriority() const472*89c4ff92SAndroid Build Coastguard Worker Priority ArmnnPreparedModel::GetModelPriority() const
473*89c4ff92SAndroid Build Coastguard Worker {
474*89c4ff92SAndroid Build Coastguard Worker return m_ModelPriority;
475*89c4ff92SAndroid Build Coastguard Worker }
476*89c4ff92SAndroid Build Coastguard Worker
477*89c4ff92SAndroid Build Coastguard Worker
executeFenced(const Request & request,const std::vector<SyncFence> & waitFor,MeasureTiming measureTiming,const OptionalTimePoint & deadline,const OptionalDuration &,const OptionalDuration &,const std::vector<android::nn::TokenValuePair> & hints,const std::vector<android::nn::ExtensionNameAndPrefix> & extensionNameToPrefix) const478*89c4ff92SAndroid Build Coastguard Worker GeneralResult<std::pair<SyncFence, ExecuteFencedInfoCallback>> ArmnnPreparedModel::executeFenced(
479*89c4ff92SAndroid Build Coastguard Worker const Request& request,
480*89c4ff92SAndroid Build Coastguard Worker const std::vector<SyncFence>& waitFor,
481*89c4ff92SAndroid Build Coastguard Worker MeasureTiming measureTiming,
482*89c4ff92SAndroid Build Coastguard Worker const OptionalTimePoint& deadline,
483*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration&,
484*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration&,
485*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints,
486*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const
487*89c4ff92SAndroid Build Coastguard Worker {
488*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::executeFenced()";
489*89c4ff92SAndroid Build Coastguard Worker
490*89c4ff92SAndroid Build Coastguard Worker if (!m_PrepareFromCache) {
491*89c4ff92SAndroid Build Coastguard Worker const auto modelRequest = validateRequestForModel(request, m_Model);
492*89c4ff92SAndroid Build Coastguard Worker if (!modelRequest.ok())
493*89c4ff92SAndroid Build Coastguard Worker {
494*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << modelRequest.error();
495*89c4ff92SAndroid Build Coastguard Worker }
496*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::executeFenced(): " << GetModelSummary(m_Model).c_str();
497*89c4ff92SAndroid Build Coastguard Worker }
498*89c4ff92SAndroid Build Coastguard Worker if (hasDeadlinePassed(deadline))
499*89c4ff92SAndroid Build Coastguard Worker {
500*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
501*89c4ff92SAndroid Build Coastguard Worker }
502*89c4ff92SAndroid Build Coastguard Worker
503*89c4ff92SAndroid Build Coastguard Worker CanonicalExecutionContext ctx;
504*89c4ff92SAndroid Build Coastguard Worker if (measureTiming == MeasureTiming::YES)
505*89c4ff92SAndroid Build Coastguard Worker {
506*89c4ff92SAndroid Build Coastguard Worker ctx.measureTimings = measureTiming;
507*89c4ff92SAndroid Build Coastguard Worker ctx.driverStart = Clock::now();
508*89c4ff92SAndroid Build Coastguard Worker }
509*89c4ff92SAndroid Build Coastguard Worker
510*89c4ff92SAndroid Build Coastguard Worker // Wait for the dependent events to signal
511*89c4ff92SAndroid Build Coastguard Worker for (const auto& syncFence : waitFor)
512*89c4ff92SAndroid Build Coastguard Worker {
513*89c4ff92SAndroid Build Coastguard Worker if (!syncFence.getSharedHandle())
514*89c4ff92SAndroid Build Coastguard Worker {
515*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT);
516*89c4ff92SAndroid Build Coastguard Worker }
517*89c4ff92SAndroid Build Coastguard Worker if (syncFence.syncWait({}) != SyncFence::FenceState::SIGNALED)
518*89c4ff92SAndroid Build Coastguard Worker {
519*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "syncWait failed";
520*89c4ff92SAndroid Build Coastguard Worker }
521*89c4ff92SAndroid Build Coastguard Worker }
522*89c4ff92SAndroid Build Coastguard Worker
523*89c4ff92SAndroid Build Coastguard Worker android::nn::TimePoint fenceExecutionStart;
524*89c4ff92SAndroid Build Coastguard Worker if (measureTiming == MeasureTiming::YES)
525*89c4ff92SAndroid Build Coastguard Worker {
526*89c4ff92SAndroid Build Coastguard Worker fenceExecutionStart = Clock::now();
527*89c4ff92SAndroid Build Coastguard Worker }
528*89c4ff92SAndroid Build Coastguard Worker
529*89c4ff92SAndroid Build Coastguard Worker // map the memory pool into shared pointers
530*89c4ff92SAndroid Build Coastguard Worker // use a shared memory pools vector on the heap, as it is passed to the request thread
531*89c4ff92SAndroid Build Coastguard Worker auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
532*89c4ff92SAndroid Build Coastguard Worker
533*89c4ff92SAndroid Build Coastguard Worker // allocate the tensors on the heap, as they are passed to the request thread
534*89c4ff92SAndroid Build Coastguard Worker auto inputTensors = std::make_shared<armnn::InputTensors>();
535*89c4ff92SAndroid Build Coastguard Worker auto outputTensors = std::make_shared<armnn::OutputTensors>();
536*89c4ff92SAndroid Build Coastguard Worker
537*89c4ff92SAndroid Build Coastguard Worker auto isPointerTypeMemory = IsPointerTypeMemory(request);
538*89c4ff92SAndroid Build Coastguard Worker ErrorStatus theErrorStatus = PrepareMemoryForIO(*inputTensors,
539*89c4ff92SAndroid Build Coastguard Worker *outputTensors,
540*89c4ff92SAndroid Build Coastguard Worker *memPools,
541*89c4ff92SAndroid Build Coastguard Worker request,
542*89c4ff92SAndroid Build Coastguard Worker isPointerTypeMemory);
543*89c4ff92SAndroid Build Coastguard Worker
544*89c4ff92SAndroid Build Coastguard Worker if (theErrorStatus != ErrorStatus::NONE)
545*89c4ff92SAndroid Build Coastguard Worker {
546*89c4ff92SAndroid Build Coastguard Worker return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "executeFenced() failed";
547*89c4ff92SAndroid Build Coastguard Worker }
548*89c4ff92SAndroid Build Coastguard Worker
549*89c4ff92SAndroid Build Coastguard Worker Timing timingSinceLaunch = {};
550*89c4ff92SAndroid Build Coastguard Worker Timing timingAfterFence = {};
551*89c4ff92SAndroid Build Coastguard Worker if (measureTiming == MeasureTiming::YES)
552*89c4ff92SAndroid Build Coastguard Worker {
553*89c4ff92SAndroid Build Coastguard Worker timingAfterFence.timeOnDevice = ctx.deviceEnd - ctx.deviceStart;
554*89c4ff92SAndroid Build Coastguard Worker timingAfterFence.timeInDriver = ctx.driverEnd - fenceExecutionStart;
555*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "executeFenced timingSinceLaunch = " << timingAfterFence.timeOnDevice;
556*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "executeFenced timingAfterFence = " << timingAfterFence.timeInDriver;
557*89c4ff92SAndroid Build Coastguard Worker }
558*89c4ff92SAndroid Build Coastguard Worker
559*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnCanonicalPreparedModel::executeFenced(...) before ExecuteGraph";
560*89c4ff92SAndroid Build Coastguard Worker auto errorStatus = ExecuteGraph(memPools, *inputTensors, *outputTensors, ctx, isPointerTypeMemory);
561*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnCanonicalPreparedModel::executeFenced(...) after ExecuteGraph";
562*89c4ff92SAndroid Build Coastguard Worker
563*89c4ff92SAndroid Build Coastguard Worker ExecuteFencedInfoCallback armnnFencedExecutionCallback =
564*89c4ff92SAndroid Build Coastguard Worker [timingSinceLaunch, timingAfterFence, errorStatus]() {
565*89c4ff92SAndroid Build Coastguard Worker
566*89c4ff92SAndroid Build Coastguard Worker GeneralResult<std::pair<Timing, Timing>> result;
567*89c4ff92SAndroid Build Coastguard Worker
568*89c4ff92SAndroid Build Coastguard Worker switch(errorStatus)
569*89c4ff92SAndroid Build Coastguard Worker {
570*89c4ff92SAndroid Build Coastguard Worker case ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
571*89c4ff92SAndroid Build Coastguard Worker result.error().code = (ErrorStatus::OUTPUT_INSUFFICIENT_SIZE);
572*89c4ff92SAndroid Build Coastguard Worker [[fallthrough]];
573*89c4ff92SAndroid Build Coastguard Worker case ErrorStatus::GENERAL_FAILURE:
574*89c4ff92SAndroid Build Coastguard Worker result.error().code = (ErrorStatus::GENERAL_FAILURE);
575*89c4ff92SAndroid Build Coastguard Worker [[fallthrough]];
576*89c4ff92SAndroid Build Coastguard Worker case ErrorStatus::INVALID_ARGUMENT:
577*89c4ff92SAndroid Build Coastguard Worker result.error().code = (ErrorStatus::INVALID_ARGUMENT);
578*89c4ff92SAndroid Build Coastguard Worker [[fallthrough]];
579*89c4ff92SAndroid Build Coastguard Worker default:
580*89c4ff92SAndroid Build Coastguard Worker {
581*89c4ff92SAndroid Build Coastguard Worker result.value() = std::make_pair(timingSinceLaunch, timingAfterFence);
582*89c4ff92SAndroid Build Coastguard Worker }
583*89c4ff92SAndroid Build Coastguard Worker }
584*89c4ff92SAndroid Build Coastguard Worker return result;
585*89c4ff92SAndroid Build Coastguard Worker };
586*89c4ff92SAndroid Build Coastguard Worker return std::make_pair(SyncFence::createAsSignaled(), std::move(armnnFencedExecutionCallback ));
587*89c4ff92SAndroid Build Coastguard Worker }
588*89c4ff92SAndroid Build Coastguard Worker
createReusableExecution(const Request & request,MeasureTiming measureTiming,const OptionalDuration & loopTimeoutDuration,const std::vector<android::nn::TokenValuePair> & hints,const std::vector<android::nn::ExtensionNameAndPrefix> & extensionNameToPrefix) const589*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedExecution> ArmnnPreparedModel::createReusableExecution(
590*89c4ff92SAndroid Build Coastguard Worker const Request& request,
591*89c4ff92SAndroid Build Coastguard Worker MeasureTiming measureTiming,
592*89c4ff92SAndroid Build Coastguard Worker const OptionalDuration& loopTimeoutDuration,
593*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::TokenValuePair>& hints,
594*89c4ff92SAndroid Build Coastguard Worker const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const
595*89c4ff92SAndroid Build Coastguard Worker {
596*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::createReusableExecution()";
597*89c4ff92SAndroid Build Coastguard Worker return std::make_shared<DefaultExecution>(shared_from_this(),
598*89c4ff92SAndroid Build Coastguard Worker request,
599*89c4ff92SAndroid Build Coastguard Worker measureTiming,
600*89c4ff92SAndroid Build Coastguard Worker loopTimeoutDuration);
601*89c4ff92SAndroid Build Coastguard Worker }
602*89c4ff92SAndroid Build Coastguard Worker
configureExecutionBurst() const603*89c4ff92SAndroid Build Coastguard Worker GeneralResult<SharedBurst> ArmnnPreparedModel::configureExecutionBurst() const
604*89c4ff92SAndroid Build Coastguard Worker {
605*89c4ff92SAndroid Build Coastguard Worker // TODO: Implement BURST
606*89c4ff92SAndroid Build Coastguard Worker return nullptr;
607*89c4ff92SAndroid Build Coastguard Worker }
608*89c4ff92SAndroid Build Coastguard Worker
getUnderlyingResource() const609*89c4ff92SAndroid Build Coastguard Worker std::any ArmnnPreparedModel::getUnderlyingResource() const
610*89c4ff92SAndroid Build Coastguard Worker {
611*89c4ff92SAndroid Build Coastguard Worker return &m_Model;
612*89c4ff92SAndroid Build Coastguard Worker }
613*89c4ff92SAndroid Build Coastguard Worker
614*89c4ff92SAndroid Build Coastguard Worker template<typename TensorBindingCollection>
DumpTensorsIfRequired(char const * tensorNamePrefix,const TensorBindingCollection & tensorBindings) const615*89c4ff92SAndroid Build Coastguard Worker void ArmnnPreparedModel::DumpTensorsIfRequired(char const* tensorNamePrefix,
616*89c4ff92SAndroid Build Coastguard Worker const TensorBindingCollection& tensorBindings) const
617*89c4ff92SAndroid Build Coastguard Worker {
618*89c4ff92SAndroid Build Coastguard Worker if (!m_RequestInputsAndOutputsDumpDir.empty())
619*89c4ff92SAndroid Build Coastguard Worker {
620*89c4ff92SAndroid Build Coastguard Worker const std::string requestName = std::to_string(m_NetworkId) + ".dump";
621*89c4ff92SAndroid Build Coastguard Worker for (std::size_t i = 0u; i < tensorBindings.size(); ++i)
622*89c4ff92SAndroid Build Coastguard Worker {
623*89c4ff92SAndroid Build Coastguard Worker DumpTensor(m_RequestInputsAndOutputsDumpDir,
624*89c4ff92SAndroid Build Coastguard Worker requestName,
625*89c4ff92SAndroid Build Coastguard Worker BuildTensorName(tensorNamePrefix, i),
626*89c4ff92SAndroid Build Coastguard Worker tensorBindings[i].second);
627*89c4ff92SAndroid Build Coastguard Worker }
628*89c4ff92SAndroid Build Coastguard Worker }
629*89c4ff92SAndroid Build Coastguard Worker }
630*89c4ff92SAndroid Build Coastguard Worker
~ArmnnPreparedModel()631*89c4ff92SAndroid Build Coastguard Worker ArmnnPreparedModel::~ArmnnPreparedModel()
632*89c4ff92SAndroid Build Coastguard Worker {
633*89c4ff92SAndroid Build Coastguard Worker VLOG(DRIVER) << "ArmnnPreparedModel::~ArmnnPreparedModel()";
634*89c4ff92SAndroid Build Coastguard Worker // Get a hold of the profiler used by this model.
635*89c4ff92SAndroid Build Coastguard Worker if (m_GpuProfilingEnabled)
636*89c4ff92SAndroid Build Coastguard Worker {
637*89c4ff92SAndroid Build Coastguard Worker auto profiler = m_Runtime->GetProfiler(m_NetworkId);
638*89c4ff92SAndroid Build Coastguard Worker if (profiler)
639*89c4ff92SAndroid Build Coastguard Worker {
640*89c4ff92SAndroid Build Coastguard Worker // Dump the profiling info to a file if required.
641*89c4ff92SAndroid Build Coastguard Worker DumpJsonProfilingIfRequired(m_GpuProfilingEnabled,
642*89c4ff92SAndroid Build Coastguard Worker m_RequestInputsAndOutputsDumpDir,
643*89c4ff92SAndroid Build Coastguard Worker m_NetworkId,
644*89c4ff92SAndroid Build Coastguard Worker profiler.get());
645*89c4ff92SAndroid Build Coastguard Worker }
646*89c4ff92SAndroid Build Coastguard Worker }
647*89c4ff92SAndroid Build Coastguard Worker // Unload the network associated with this model
648*89c4ff92SAndroid Build Coastguard Worker m_Runtime->UnloadNetwork(m_NetworkId);
649*89c4ff92SAndroid Build Coastguard Worker }
650*89c4ff92SAndroid Build Coastguard Worker
ExecuteWithDummyInputs(unsigned int numInputs,unsigned int numOutputs) const651*89c4ff92SAndroid Build Coastguard Worker bool ArmnnPreparedModel::ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs) const
652*89c4ff92SAndroid Build Coastguard Worker {
653*89c4ff92SAndroid Build Coastguard Worker std::vector<std::vector<char>> storage;
654*89c4ff92SAndroid Build Coastguard Worker armnn::InputTensors inputTensors;
655*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; i++)
656*89c4ff92SAndroid Build Coastguard Worker {
657*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
658*89c4ff92SAndroid Build Coastguard Worker // pInputTensors (of type InputTensors) is composed of a vector of ConstTensors.
659*89c4ff92SAndroid Build Coastguard Worker // Therefore, set all TensorInfo isConstant parameters of input Tensors to true.
660*89c4ff92SAndroid Build Coastguard Worker inputTensorInfo.SetConstant();
661*89c4ff92SAndroid Build Coastguard Worker storage.emplace_back(inputTensorInfo.GetNumBytes());
662*89c4ff92SAndroid Build Coastguard Worker const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());
663*89c4ff92SAndroid Build Coastguard Worker
664*89c4ff92SAndroid Build Coastguard Worker inputTensors.emplace_back(i, inputTensor);
665*89c4ff92SAndroid Build Coastguard Worker }
666*89c4ff92SAndroid Build Coastguard Worker
667*89c4ff92SAndroid Build Coastguard Worker armnn::OutputTensors outputTensors;
668*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numOutputs; i++)
669*89c4ff92SAndroid Build Coastguard Worker {
670*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
671*89c4ff92SAndroid Build Coastguard Worker storage.emplace_back(outputTensorInfo.GetNumBytes());
672*89c4ff92SAndroid Build Coastguard Worker const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());
673*89c4ff92SAndroid Build Coastguard Worker
674*89c4ff92SAndroid Build Coastguard Worker outputTensors.emplace_back(i, outputTensor);
675*89c4ff92SAndroid Build Coastguard Worker }
676*89c4ff92SAndroid Build Coastguard Worker CanonicalExecutionContext ctx;
677*89c4ff92SAndroid Build Coastguard Worker ctx.measureTimings = MeasureTiming::NO;
678*89c4ff92SAndroid Build Coastguard Worker auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
679*89c4ff92SAndroid Build Coastguard Worker
680*89c4ff92SAndroid Build Coastguard Worker auto errorStatus = ExecuteGraph(memPools,
681*89c4ff92SAndroid Build Coastguard Worker inputTensors,
682*89c4ff92SAndroid Build Coastguard Worker outputTensors,
683*89c4ff92SAndroid Build Coastguard Worker ctx);
684*89c4ff92SAndroid Build Coastguard Worker
685*89c4ff92SAndroid Build Coastguard Worker return errorStatus == ErrorStatus::NONE;
686*89c4ff92SAndroid Build Coastguard Worker }
687*89c4ff92SAndroid Build Coastguard Worker
688*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn_driver
689